sglang 0.4.9.post4__py3-none-any.whl → 0.4.9.post6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/lang/chat_template.py +21 -0
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/model_config.py +7 -0
- sglang/srt/constrained/base_grammar_backend.py +10 -2
- sglang/srt/constrained/xgrammar_backend.py +7 -5
- sglang/srt/conversation.py +16 -1
- sglang/srt/debug_utils/__init__.py +0 -0
- sglang/srt/debug_utils/dump_comparator.py +131 -0
- sglang/srt/debug_utils/dumper.py +108 -0
- sglang/srt/debug_utils/text_comparator.py +172 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
- sglang/srt/disaggregation/mooncake/conn.py +16 -0
- sglang/srt/disaggregation/prefill.py +13 -1
- sglang/srt/entrypoints/engine.py +4 -2
- sglang/srt/entrypoints/http_server.py +13 -1
- sglang/srt/entrypoints/openai/protocol.py +3 -1
- sglang/srt/entrypoints/openai/serving_base.py +5 -2
- sglang/srt/entrypoints/openai/serving_chat.py +132 -79
- sglang/srt/function_call/ebnf_composer.py +10 -3
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +164 -0
- sglang/srt/function_call/qwen3_coder_detector.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
- sglang/srt/layers/attention/vision.py +56 -8
- sglang/srt/layers/layernorm.py +26 -1
- sglang/srt/layers/logits_processor.py +14 -3
- sglang/srt/layers/moe/ep_moe/layer.py +323 -242
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
- sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
- sglang/srt/layers/moe/topk.py +90 -24
- sglang/srt/layers/multimodal.py +11 -8
- sglang/srt/layers/quantization/fp8.py +25 -247
- sglang/srt/layers/quantization/fp8_kernel.py +78 -48
- sglang/srt/layers/quantization/modelopt_quant.py +27 -10
- sglang/srt/layers/quantization/unquant.py +24 -76
- sglang/srt/layers/quantization/w4afp8.py +68 -17
- sglang/srt/lora/lora_registry.py +93 -29
- sglang/srt/managers/cache_controller.py +9 -7
- sglang/srt/managers/data_parallel_controller.py +4 -0
- sglang/srt/managers/io_struct.py +12 -0
- sglang/srt/managers/mm_utils.py +154 -35
- sglang/srt/managers/multimodal_processor.py +3 -14
- sglang/srt/managers/schedule_batch.py +14 -8
- sglang/srt/managers/scheduler.py +64 -1
- sglang/srt/managers/scheduler_input_blocker.py +106 -0
- sglang/srt/managers/tokenizer_manager.py +80 -15
- sglang/srt/managers/tp_worker.py +8 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -2
- sglang/srt/model_executor/model_runner.py +83 -27
- sglang/srt/models/deepseek_v2.py +75 -84
- sglang/srt/models/glm4_moe.py +1035 -0
- sglang/srt/models/glm4_moe_nextn.py +167 -0
- sglang/srt/models/interns1.py +328 -0
- sglang/srt/models/internvl.py +143 -47
- sglang/srt/models/llava.py +9 -5
- sglang/srt/models/minicpmo.py +4 -1
- sglang/srt/models/qwen2_moe.py +2 -2
- sglang/srt/models/qwen3_moe.py +17 -71
- sglang/srt/multimodal/processors/base_processor.py +20 -6
- sglang/srt/multimodal/processors/clip.py +2 -2
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
- sglang/srt/multimodal/processors/gemma3.py +2 -2
- sglang/srt/multimodal/processors/gemma3n.py +2 -2
- sglang/srt/multimodal/processors/internvl.py +21 -8
- sglang/srt/multimodal/processors/janus_pro.py +2 -2
- sglang/srt/multimodal/processors/kimi_vl.py +2 -2
- sglang/srt/multimodal/processors/llava.py +4 -4
- sglang/srt/multimodal/processors/minicpm.py +2 -3
- sglang/srt/multimodal/processors/mlama.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +18 -111
- sglang/srt/multimodal/processors/phi4mm.py +2 -2
- sglang/srt/multimodal/processors/pixtral.py +2 -2
- sglang/srt/multimodal/processors/qwen_audio.py +2 -2
- sglang/srt/multimodal/processors/qwen_vl.py +2 -2
- sglang/srt/multimodal/processors/vila.py +3 -1
- sglang/srt/poll_based_barrier.py +31 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +65 -6
- sglang/srt/two_batch_overlap.py +8 -3
- sglang/srt/utils.py +96 -1
- sglang/srt/weight_sync/utils.py +119 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_utils.py +118 -5
- sglang/utils.py +19 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/METADATA +5 -4
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/RECORD +97 -80
- sglang/srt/debug_utils.py +0 -74
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/hybrid_attn_backend.py
ADDED
@@ -0,0 +1,100 @@
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+
+
+class HybridAttnBackend(AttentionBackend):
+    """Support different backends for prefill and decode."""
+
+    def __init__(
+        self, prefill_backend: AttentionBackend, decode_backend: AttentionBackend
+    ):
+        self.prefill_backend = prefill_backend
+        self.decode_backend = decode_backend
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        if forward_batch.forward_mode.is_decode():
+            self.decode_backend.init_forward_metadata(forward_batch)
+        else:
+            self.prefill_backend.init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+    ):
+        self.decode_backend.init_forward_metadata_capture_cuda_graph(
+            bs,
+            num_tokens,
+            req_pool_indices,
+            seq_lens,
+            encoder_lens,
+            forward_mode,
+            spec_info,
+        )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        self.decode_backend.init_forward_metadata_replay_cuda_graph(
+            bs,
+            req_pool_indices,
+            seq_lens,
+            seq_lens_sum,
+            encoder_lens,
+            forward_mode,
+            spec_info,
+            seq_lens_cpu,
+        )
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        return self.decode_backend.get_cuda_graph_seq_len_fill_value()
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        **kwargs,
+    ):
+        return self.decode_backend.forward_decode(
+            q, k, v, layer, forward_batch, save_kv_cache, **kwargs
+        )
+
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        **kwargs,
+    ):
+        return self.prefill_backend.forward_extend(
+            q, k, v, layer, forward_batch, save_kv_cache, **kwargs
+        )
sglang/srt/layers/attention/vision.py
CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 import dataclasses
 import functools
 import math
-from functools import lru_cache
+from functools import lru_cache, partial
 from typing import Any, Optional, Tuple, Union
 
 import torch
@@ -18,11 +18,16 @@ _is_cuda = is_cuda()
 if _is_cuda:
     from sgl_kernel.flash_attn import flash_attn_varlen_func
 
-from sglang.srt.distributed import parallel_state
+from sglang.srt.distributed import (
+    parallel_state,
+    split_tensor_along_last_dim,
+    tensor_model_parallel_all_gather,
+)
 from sglang.srt.distributed import utils as dist_utils
 from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,
 )
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
@@ -349,25 +354,44 @@ class VisionAttention(nn.Module):
         flatten_batch: bool = False,
         prefix: str = "",
         proj_bias: bool = True,
+        num_dummy_heads: int = 0,
+        qkv_bias: bool = True,
+        qk_normalization: bool = False,
+        layer_norm_eps: float = 1e-06,
         **kwargs,
     ):
         super().__init__()
         world_size = parallel_state.get_tensor_model_parallel_world_size()
+        self.tp_size = world_size
+        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
         self.dropout = dropout
         self.head_size = embed_dim // num_heads
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads
         )
         self.num_attention_heads_per_partition = dist_utils.divide(
-            num_heads, world_size
+            num_dummy_heads + num_heads, world_size
         )
         self.num_attention_kv_heads_per_partition = dist_utils.divide(
-            num_heads, world_size
+            num_dummy_heads + num_heads, world_size
         )
 
         self.q_size = self.num_attention_heads_per_partition * self.head_size
         self.kv_size = self.num_attention_kv_heads_per_partition * self.head_size
 
+        self.qk_normalization = qk_normalization
+
+        # Additional dummy heads are used to enable TP for common GPU counts.
+        self.dummy_dim = (num_dummy_heads + num_heads) * self.head_size
+
+        if self.qk_normalization:
+            self.q_norm = RMSNorm(
+                self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
+            )
+            self.k_norm = RMSNorm(
+                self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
+            )
+
         if global_server_args_dict["mm_attention_backend"] is None:
             if qkv_backend is None:
                 qkv_backend = "sdpa"
@@ -391,26 +415,46 @@ class VisionAttention(nn.Module):
             self.qkv_proj = QKVParallelLinear(
                 hidden_size=embed_dim,
                 head_size=self.head_size,
-                total_num_heads=num_heads,
-                total_num_kv_heads=num_heads,
+                total_num_heads=num_dummy_heads + num_heads,
+                total_num_kv_heads=num_dummy_heads + num_heads,
+                bias=qkv_bias,
                 quant_config=quant_config,
                 prefix=add_prefix("qkv_proj", prefix),
             )
         else:
            self.qkv_proj = ColumnParallelLinear(
                 input_size=embed_dim,
-                output_size=3 *
+                output_size=3 * self.dummy_dim,
+                bias=qkv_bias,
                 quant_config=quant_config,
                 prefix=add_prefix("qkv_proj", prefix),
             )
         self.proj = RowParallelLinear(
-            input_size=
+            input_size=self.dummy_dim,
             output_size=embed_dim,
             bias=proj_bias,
             quant_config=quant_config,
             prefix=add_prefix("proj", prefix),
         )
 
+    def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
+        """apply qk norm for internvl vit attn"""
+        q = q.flatten(1, 2)
+        k = k.flatten(1, 2)
+
+        if self.tp_size > 1:
+            q = tensor_model_parallel_all_gather(q.contiguous())
+            k = tensor_model_parallel_all_gather(k.contiguous())
+        q = self.q_norm(q)
+        k = self.k_norm(k)
+        if self.tp_size > 1:
+            splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+        q = q.unflatten(-1, (-1, self.head_size))
+        k = k.unflatten(-1, (-1, self.head_size))
+        return q, k
+
     def forward(
         self,
         x: torch.Tensor,
@@ -489,6 +533,10 @@ class VisionAttention(nn.Module):
         assert k.dim() == 3, k.dim()
         assert v.dim() == 3, v.dim()
 
+        # internvl
+        if self.qk_normalization:
+            q, k = self._apply_qk_norm(q, k)
+
         output = self.qkv_backend.forward(
             q=q,
             k=k,
sglang/srt/layers/layernorm.py
CHANGED
@@ -61,10 +61,15 @@ class RMSNorm(CustomOp):
         self,
         hidden_size: int,
         eps: float = 1e-6,
+        var_hidden_size: Optional[int] = None,
     ) -> None:
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
+        self.hidden_size = hidden_size
+        self.variance_size_override = (
+            None if var_hidden_size == hidden_size else var_hidden_size
+        )
         if _use_aiter:
             self._forward_method = self.forward_aiter
 
@@ -73,6 +78,8 @@ class RMSNorm(CustomOp):
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if self.variance_size_override is not None:
+            return self.forward_native(x, residual)
         if residual is not None:
             fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
             return x, residual
@@ -138,7 +145,25 @@ class RMSNorm(CustomOp):
             x = x + residual.to(torch.float32)
             residual = x.to(orig_dtype)
 
-        variance = x.pow(2).mean(dim=-1, keepdim=True)
+        hidden_size = x.shape[-1]
+        if hidden_size != self.hidden_size:
+            raise ValueError(
+                "Expected hidden_size to be "
+                f"{self.hidden_size}, but found: {hidden_size}"
+            )
+
+        if self.variance_size_override is None:
+            x_var = x
+        else:
+            if hidden_size < self.variance_size_override:
+                raise ValueError(
+                    "Expected hidden_size to be at least "
+                    f"{self.variance_size_override}, but found: {hidden_size}"
+                )
+
+            x_var = x[..., : self.variance_size_override]
+
+        variance = x_var.pow(2).mean(dim=-1, keepdim=True)
         x = x * torch.rsqrt(variance + self.variance_epsilon)
         x = (x * self.weight).to(orig_dtype)
         if residual is None:
sglang/srt/layers/logits_processor.py
CHANGED
@@ -170,8 +170,6 @@ class LogitsMetadata:
         )
 
     def compute_dp_attention_metadata(self):
-        # TODO(ch-wan): gathered_buffer here is larger than the actual required size in draft extend,
-        # we may use a smaller buffer in draft extend.
 
         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
         dp_rank = get_attention_dp_rank()
@@ -186,6 +184,19 @@ class LogitsMetadata:
         self.dp_local_start_pos = dp_local_start_pos
         self.dp_local_num_tokens = dp_local_num_tokens
 
+        if self.global_num_tokens_for_logprob_cpu is not None:
+            # create a smaller buffer to reduce peak memory usage
+            self.gathered_buffer = torch.empty(
+                (
+                    sum(self.global_num_tokens_for_logprob_cpu),
+                    self.gathered_buffer.shape[1],
+                ),
+                dtype=self.gathered_buffer.dtype,
+                device=self.gathered_buffer.device,
+            )
+        else:
+            self.gathered_buffer = torch.empty_like(self.gathered_buffer)
+
 
 class LogitsProcessor(nn.Module):
     def __init__(
@@ -430,7 +441,7 @@ class LogitsProcessor(nn.Module):
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits_metadata.compute_dp_attention_metadata()
             hidden_states, local_hidden_states = (
-                torch.empty_like(logits_metadata.gathered_buffer),
+                logits_metadata.gathered_buffer,
                 hidden_states,
             )
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
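The `compute_dp_attention_metadata` change replaces the worst-case allocation at the gather site with a buffer sized to the tokens that actually need logprobs, summed across data-parallel ranks ("create a smaller buffer to reduce peak memory usage"). A small sketch of that sizing logic; the token counts and buffer shapes are illustrative only.

```python
import torch

# Sketch of the smaller gathered-buffer allocation added above: size the
# buffer by the actual per-rank logprob token counts instead of reusing the
# padded worst-case buffer. All numbers here are illustrative.
global_num_tokens_for_logprob_cpu = [3, 1, 5, 2]        # tokens per DP rank
padded_buffer = torch.empty(64, 4096)                    # old worst-case shape

gathered_buffer = torch.empty(
    (sum(global_num_tokens_for_logprob_cpu), padded_buffer.shape[1]),
    dtype=padded_buffer.dtype,
    device=padded_buffer.device,
)
print(gathered_buffer.shape)  # torch.Size([11, 4096])
```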