sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -2
- sglang/bench_serving.py +224 -127
- sglang/compile_deep_gemm.py +3 -0
- sglang/launch_server.py +0 -14
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/falcon_h1.py +12 -58
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +68 -31
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +11 -43
- sglang/srt/disaggregation/decode.py +7 -18
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/nixl/conn.py +55 -23
- sglang/srt/disaggregation/prefill.py +17 -32
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/entrypoints/grpc_request_manager.py +10 -23
- sglang/srt/entrypoints/grpc_server.py +220 -80
- sglang/srt/entrypoints/http_server.py +49 -1
- sglang/srt/entrypoints/openai/protocol.py +159 -31
- sglang/srt/entrypoints/openai/serving_chat.py +13 -71
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +4 -0
- sglang/srt/function_call/function_call_parser.py +8 -6
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
- sglang/srt/layers/attention/attention_registry.py +31 -22
- sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
- sglang/srt/layers/attention/flashattention_backend.py +0 -1
- sglang/srt/layers/attention/flashinfer_backend.py +223 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/triton_backend.py +1 -1
- sglang/srt/layers/logits_processor.py +136 -6
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
- sglang/srt/layers/moe/ep_moe/layer.py +8 -286
- sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/utils.py +7 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/modelopt_quant.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/w4afp8.py +2 -16
- sglang/srt/lora/lora_manager.py +0 -8
- sglang/srt/managers/overlap_utils.py +18 -16
- sglang/srt/managers/schedule_batch.py +119 -90
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +213 -126
- sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
- sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
- sglang/srt/managers/tokenizer_manager.py +270 -53
- sglang/srt/managers/tp_worker.py +39 -28
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +162 -68
- sglang/srt/mem_cache/radix_cache.py +8 -3
- sglang/srt/mem_cache/swa_radix_cache.py +70 -14
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_executor/forward_batch_info.py +4 -18
- sglang/srt/model_executor/model_runner.py +55 -51
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +187 -6
- sglang/srt/model_loader/weight_utils.py +3 -0
- sglang/srt/models/falcon_h1.py +11 -9
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +11 -1
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/utils.py +5 -1
- sglang/srt/sampling/sampling_batch_info.py +11 -9
- sglang/srt/server_args.py +100 -33
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_utils.py +0 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils/common.py +18 -0
- sglang/srt/utils/hf_transformers_utils.py +2 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +40 -0
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +18 -2
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +63 -0
- sglang/test/test_utils.py +32 -11
- sglang/version.py +1 -1
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/mixtral.py
CHANGED
@@ -36,7 +36,6 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -94,8 +93,7 @@ class MixtralMoE(nn.Module):
             renormalize=True,
         )
 
-
-        self.experts = MoEImpl(
+        self.experts = FusedMoE(
             num_experts=num_experts,
             top_k=top_k,
             layer_id=layer_id,
sglang/srt/models/mllama4.py
CHANGED
@@ -2,6 +2,7 @@ import json as json_lib
 import logging
 import math
 import os
+import re
 from collections.abc import Iterable
 from typing import List, Optional, Set, Tuple
 
@@ -422,6 +423,11 @@ class Llama4ForConditionalGeneration(nn.Module):
         "gate_up_proj": ["gate_proj", "up_proj"],
     }
 
+    # Pattern to match language model layers only (skip vision_model and multi_modal_projector)
+    lora_pattern = re.compile(
+        r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
+    )
+
     def __init__(
         self,
         config: Llama4Config,
@@ -555,6 +561,10 @@ class Llama4ForConditionalGeneration(nn.Module):
 
         return projected_vision_flat
 
+    def should_apply_lora(self, module_name: str) -> bool:
+        """Skip vision model and multi_modal_projector for LoRA."""
+        return bool(self.lora_pattern.match(module_name))
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -700,7 +710,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         """Handle scale parameter remapping. Returns True if handled."""
         if "scale" in name and "expert" not in name:
             remapped_name = maybe_remap_kv_scale_name(name, params_dict)
-            return remapped_name
+            return remapped_name != name
         return False
 
     def _handle_stacked_params(
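Note on the mllama4.py change above: the new should_apply_lora hook restricts LoRA to the text decoder's projection modules. A minimal standalone sketch of how the added pattern behaves; the module names below are illustrative examples, not taken from the diff:

import re

# Same pattern as added to Llama4ForConditionalGeneration above.
lora_pattern = re.compile(
    r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
)

# Language-model projections match, so LoRA is applied to them.
assert lora_pattern.match("language_model.model.layers.3.self_attn.qkv_proj")
assert lora_pattern.match("language_model.model.layers.17.mlp.gate_up_proj")

# Vision tower and projector modules do not match, so they are skipped.
assert not lora_pattern.match("vision_model.model.layers.0.self_attn.qkv_proj")
assert not lora_pattern.match("multi_modal_projector.linear_1")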
sglang/srt/models/nemotron_h.py
ADDED
@@ -0,0 +1,514 @@
+# Copyright 2023-2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/nemotron_h.py
+
+"""Inference-only NemotronH model."""
+
+from collections.abc import Iterable
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from sglang.srt.configs import NemotronHConfig
+from sglang.srt.configs.nemotron_h import ATTENTION, MAMBA, MLP
+from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from sglang.srt.layers.activation import ReLU2
+from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
+    HybridLinearAttnBackend,
+    Mamba2AttnBackend,
+)
+from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE,
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
+from sglang.srt.utils import add_prefix, make_layers_non_pp
+from sglang.utils import logger
+
+
+class NemotronHMLP(nn.Module):
+    def __init__(
+        self,
+        config: NemotronHConfig,
+        layer_idx: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        hybrid_override_pattern = config.hybrid_override_pattern
+        mlp_index = hybrid_override_pattern[: layer_idx + 1].count("-") - 1
+        if isinstance(config.intermediate_size, list):
+            if len(config.intermediate_size) == 1:
+                intermediate_size = config.intermediate_size[0]
+            else:
+                intermediate_size = config.intermediate_size[mlp_index]
+        else:
+            intermediate_size = config.intermediate_size
+
+        self.up_proj = ColumnParallelLinear(
+            input_size=config.hidden_size,
+            output_size=intermediate_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.up_proj",
+        )
+        self.down_proj = RowParallelLinear(
+            input_size=intermediate_size,
+            output_size=config.hidden_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
+        )
+        self.act_fn = ReLU2()
+
+    def forward(self, x: torch.Tensor):
+        x, _ = self.up_proj(x)
+        x = self.act_fn(x)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class NemotronHMLPDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: NemotronHConfig,
+        layer_idx: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+
+        self.mixer = NemotronHMLP(
+            config,
+            quant_config=quant_config,
+            bias=config.mlp_bias,
+            prefix=f"{prefix}.mixer",
+            layer_idx=layer_idx,
+        )
+
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        *,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        forward_batch: ForwardBatch,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.norm(hidden_states)
+        else:
+            hidden_states, residual = self.norm(hidden_states, residual)
+
+        hidden_states = self.mixer.forward(hidden_states)
+        return hidden_states, residual
+
+
+class NemotronHMambaDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: NemotronHConfig,
+        layer_idx: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.layer_id = layer_idx
+        self.mixer = MambaMixer2(
+            cache_params=config.mamba2_cache_params,
+            hidden_size=config.hidden_size,
+            use_conv_bias=config.use_conv_bias,
+            use_bias=config.use_bias,
+            n_groups=config.mamba_n_groups,
+            rms_norm_eps=config.rms_norm_eps,
+            activation=config.mamba_hidden_act,
+            quant_config=quant_config,
+        )
+
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        *,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        forward_batch: ForwardBatch,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.norm(hidden_states)
+        else:
+            hidden_states, residual = self.norm(hidden_states, residual)
+
+        output = torch.empty_like(hidden_states)
+        attn_backend = forward_batch.attn_backend
+        assert isinstance(attn_backend, HybridLinearAttnBackend)
+        assert isinstance(attn_backend.linear_attn_backend, Mamba2AttnBackend)
+        attn_backend.linear_attn_backend.forward(
+            mixer=self.mixer,
+            layer_id=self.layer_id,
+            hidden_states=hidden_states,
+            output=output,
+            use_triton_causal_conv=True,  # TODO: investigate need of `use_triton_causal_conv`
+        )
+        return output, residual
+
+
+class NemotronHAttention(nn.Module):
+    def __init__(
+        self,
+        config: NemotronHConfig,
+        layer_idx: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = config.num_attention_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = config.num_key_value_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        if hasattr(config, "head_dim") and config.head_dim is not None:
+            self.head_dim = config.head_dim
+        else:
+            self.head_dim = config.hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+
+        self.qkv_proj = QKVParallelLinear(
+            config.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
+        )
+
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_idx,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
+
+    def forward(
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        attn_output = self.attn.forward(q, k, v, forward_batch)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class NemotronHAttentionDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: NemotronHConfig,
+        layer_idx: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.mixer = NemotronHAttention(
+            config,
+            layer_idx,
+            quant_config,
+            prefix=f"{prefix}.mixer",
+        )
+
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        *,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        forward_batch: ForwardBatch,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.norm(hidden_states)
+        else:
+            hidden_states, residual = self.norm(hidden_states, residual)
+
+        hidden_states = self.mixer.forward(
+            hidden_states=hidden_states, forward_batch=forward_batch
+        )
+        return hidden_states, residual
+
+
+Layers = (
+    NemotronHAttentionDecoderLayer
+    | NemotronHMLPDecoderLayer
+    | NemotronHMambaDecoderLayer
+)
+ALL_DECODER_LAYER_TYPES: dict[str, type[Layers]] = {
+    ATTENTION: NemotronHAttentionDecoderLayer,
+    MLP: NemotronHMLPDecoderLayer,
+    MAMBA: NemotronHMambaDecoderLayer,
+}
+
+
+class NemotronHModel(nn.Module):
+    def __init__(
+        self,
+        *,
+        config: NemotronHConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        lora_config = None
+        self.config = config
+        lora_vocab = (
+            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
+            if lora_config
+            else 0
+        )
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.org_vocab_size = config.vocab_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            self.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+        )
+
+        def get_layer(idx: int, prefix: str):
+            layer_class = ALL_DECODER_LAYER_TYPES[config.hybrid_override_pattern[idx]]
+            return layer_class(config, idx, quant_config=quant_config, prefix=prefix)
+
+        self.layers = make_layers_non_pp(
+            len(config.hybrid_override_pattern), get_layer, prefix=f"{prefix}.layers"
+        )
+        self.norm_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        residual = None
+        for layer in self.layers:
+            if not isinstance(layer, Layers):
+                raise ValueError(f"Unknown layer type: {type(layer)}")
+            hidden_states, residual = layer.forward(
+                hidden_states=hidden_states,
+                residual=residual,
+                forward_batch=forward_batch,
+            )
+
+        if not get_pp_group().is_last_rank:
+            return PPProxyTensors(
+                {"hidden_states": hidden_states, "residual": residual}
+            )
+        hidden_states, _ = self.norm_f(hidden_states, residual)
+        return hidden_states
+
+
+class NemotronHForCausalLM(nn.Module):
+    remap_prefix = {"backbone": "model"}
+    remap_substr = {"A_log": "A", "embeddings": "embed_tokens"}
+
+    # LoRA specific attributes
+    embedding_modules = {
+        "embed_tokens": "input_embeddings",
+        "lm_head": "output_embeddings",
+    }
+    embedding_padding_modules = ["lm_head"]
+
+    def __init__(
+        self,
+        *,
+        config: NemotronHConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        lora_config = None
+        self.config = config
+        self.model = self._init_model(
+            config=config, quant_config=quant_config, prefix=prefix
+        )
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.unpadded_vocab_size = config.vocab_size
+            if lora_config:
+                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=(
+                    DEFAULT_VOCAB_PADDING_SIZE
+                    # We need bigger padding if using lora for kernel
+                    # compatibility
+                    if not lora_config
+                    else lora_config.lora_vocab_padding_size
+                ),
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
+            )
+        self.logits_processor = LogitsProcessor(config)
+
+    def _init_model(
+        self,
+        config: NemotronHConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        return NemotronHModel(config=config, quant_config=quant_config, prefix=prefix)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: Optional[torch.Tensor] = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ):
+        hidden_states = self.model.forward(
+            input_ids, positions, forward_batch, pp_proxy_tensors, input_embeds
+        )
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
+        return self.mamba_cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs)
+
+    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
+        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
+        updated_weights = []
+        for name, loaded_weight in weights:
+            for prefix, new_key in self.remap_prefix.items():
+                if name.startswith(prefix):
+                    name = name.replace(prefix, new_key)
+            for substr, new_key in self.remap_substr.items():
+                if substr in name:
+                    name = name.replace(substr, new_key)
+            updated_weights.append((name, loaded_weight))
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in updated_weights:
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
+
+
+EntryClass = [NemotronHForCausalLM]
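A note on the layer stack in the new nemotron_h.py: NemotronHModel builds one decoder layer per character of config.hybrid_override_pattern, dispatching on each character through ALL_DECODER_LAYER_TYPES. The sketch below illustrates that dispatch with a made-up pattern string; the single-character codes are an assumption (the real values of ATTENTION, MAMBA, and MLP live in sglang.srt.configs.nemotron_h, and only MLP == "-" is implied by the diff via the count("-") logic in NemotronHMLP):

# Illustrative dispatch; codes assumed as "M" (Mamba), "*" (attention), "-" (MLP).
MAMBA, ATTENTION, MLP = "M", "*", "-"

ALL_DECODER_LAYER_TYPES = {
    ATTENTION: "NemotronHAttentionDecoderLayer",
    MLP: "NemotronHMLPDecoderLayer",
    MAMBA: "NemotronHMambaDecoderLayer",
}

hybrid_override_pattern = "M-M*-"  # hypothetical 5-layer model
for idx, code in enumerate(hybrid_override_pattern):
    print(idx, ALL_DECODER_LAYER_TYPES[code])
# 0 NemotronHMambaDecoderLayer
# 1 NemotronHMLPDecoderLayer
# 2 NemotronHMambaDecoderLayer
# 3 NemotronHAttentionDecoderLayer
# 4 NemotronHMLPDecoderLayer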
sglang/srt/models/utils.py
CHANGED
@@ -27,7 +27,11 @@ if _is_cuda:
 
 def enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
     """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
-    return
+    return (
+        _is_cuda
+        and hasattr(forward_batch.token_to_kv_pool, "dtype")
+        and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
+    )
 
 
 def create_fused_set_kv_buffer_arg(
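The models/utils.py hunk above guards the KV-pool dtype read with hasattr before enabling the fused set_kv_buffer path. A standalone restatement of the predicate, using a hypothetical stand-in pool class and a stand-in _is_cuda flag rather than sglang's real token_to_kv_pool objects:

import torch

# Stand-ins: `_is_cuda` and the pool class here are illustrative only; sglang's
# real KV pool classes live in sglang.srt.mem_cache.memory_pool.
_is_cuda = torch.cuda.is_available()


class _FakeKVPool:
    def __init__(self, dtype):
        self.dtype = dtype


def fused_set_kv_buffer_enabled(pool) -> bool:
    # Mirrors the patched check: CUDA platform, the pool exposes a dtype,
    # and the KV cache is stored in bfloat16.
    return _is_cuda and hasattr(pool, "dtype") and pool.dtype == torch.bfloat16


print(fused_set_kv_buffer_enabled(_FakeKVPool(torch.bfloat16)))  # True on a CUDA machine
print(fused_set_kv_buffer_enabled(_FakeKVPool(torch.float16)))   # False
print(fused_set_kv_buffer_enabled(object()))                     # False: no dtype attribute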
sglang/srt/sampling/sampling_batch_info.py
CHANGED
@@ -44,12 +44,9 @@ class SamplingBatchInfo:
     vocab_mask: Optional[torch.Tensor] = None
     apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
 
-    # An event used for overlap schedule
-    sampling_info_done: Optional[threading.Event] = None
-
     # Penalizer
     penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
-
+    acc_linear_penalties: torch.Tensor = None  # Used in the overlap mode
 
     # Whether any request has custom logit processor
     has_custom_logit_processor: bool = False
@@ -217,19 +214,19 @@ class SamplingBatchInfo:
 
     def update_penalties(self):
        if self.penalizer_orchestrator.is_required:
-            self.linear_penalties = torch.zeros(
+            self.acc_linear_penalties = torch.zeros(
                 (len(self.temperatures), self.vocab_size),
                 dtype=torch.float32,
                 device=self.temperatures.device,
             )
-            self.penalizer_orchestrator.apply(self.linear_penalties)
+            self.penalizer_orchestrator.apply(self.acc_linear_penalties)
         else:
-            self.linear_penalties = None
+            self.acc_linear_penalties = None
 
     def apply_logits_bias(self, logits: torch.Tensor):
-        if self.linear_penalties is not None:
+        if self.acc_linear_penalties is not None:
             # Used in the overlap mode
-            logits.add_(self.linear_penalties)
+            logits.add_(self.acc_linear_penalties)
 
         if self.penalizer_orchestrator and self.penalizer_orchestrator.is_required:
             # Used in the non-overlap mode
@@ -370,6 +367,11 @@ class SamplingBatchInfo:
         self.need_top_k_sampling |= other.need_top_k_sampling
         self.need_min_p_sampling |= other.need_min_p_sampling
 
+    def copy_for_forward(self):
+        # Accumulate the penalty into a pre-allocated buffer to get rid of the dependency of `penalizer_orchestrator` later
+        self.update_penalties()
+        return dataclasses.replace(self, penalizer_orchestrator=None)
+
 
 def merge_bias_tensor(
     lhs: Optional[torch.Tensor],
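The sampling_batch_info.py hunks replace the overlap-schedule event with an accumulated penalty buffer: copy_for_forward first bakes the penalizer output into acc_linear_penalties via update_penalties, then hands the forward path a copy with penalizer_orchestrator detached (per the comment in the diff). A toy sketch of that pattern with stand-in classes, not sglang's real SamplingBatchInfo or penalizer API:

import dataclasses
from typing import Optional

import torch


@dataclasses.dataclass
class _TinySamplingInfo:
    """Toy stand-in for SamplingBatchInfo; only the penalty path is modeled."""

    vocab_size: int
    penalizer: Optional[object] = None  # stand-in for penalizer_orchestrator
    acc_linear_penalties: Optional[torch.Tensor] = None

    def update_penalties(self):
        if self.penalizer is not None:
            # Accumulate all additive penalties into one dense buffer.
            self.acc_linear_penalties = torch.zeros(1, self.vocab_size)
            self.acc_linear_penalties[0, 0] = -1.0  # pretend the penalizer wrote here
        else:
            self.acc_linear_penalties = None

    def copy_for_forward(self):
        # Bake penalties into the buffer, then drop the orchestrator so the
        # copy used during the overlapped forward has no live dependency on it.
        self.update_penalties()
        return dataclasses.replace(self, penalizer=None)


info = _TinySamplingInfo(vocab_size=8, penalizer=object())
snapshot = info.copy_for_forward()
assert snapshot.penalizer is None
assert snapshot.acc_linear_penalties is not None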