sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. sglang/bench_one_batch.py +0 -2
  2. sglang/bench_serving.py +224 -127
  3. sglang/compile_deep_gemm.py +3 -0
  4. sglang/launch_server.py +0 -14
  5. sglang/srt/configs/__init__.py +2 -0
  6. sglang/srt/configs/falcon_h1.py +12 -58
  7. sglang/srt/configs/mamba_utils.py +117 -0
  8. sglang/srt/configs/model_config.py +68 -31
  9. sglang/srt/configs/nemotron_h.py +286 -0
  10. sglang/srt/configs/qwen3_next.py +11 -43
  11. sglang/srt/disaggregation/decode.py +7 -18
  12. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  13. sglang/srt/disaggregation/nixl/conn.py +55 -23
  14. sglang/srt/disaggregation/prefill.py +17 -32
  15. sglang/srt/entrypoints/engine.py +2 -2
  16. sglang/srt/entrypoints/grpc_request_manager.py +10 -23
  17. sglang/srt/entrypoints/grpc_server.py +220 -80
  18. sglang/srt/entrypoints/http_server.py +49 -1
  19. sglang/srt/entrypoints/openai/protocol.py +159 -31
  20. sglang/srt/entrypoints/openai/serving_chat.py +13 -71
  21. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  22. sglang/srt/environ.py +4 -0
  23. sglang/srt/function_call/function_call_parser.py +8 -6
  24. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  25. sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
  26. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
  27. sglang/srt/layers/attention/attention_registry.py +31 -22
  28. sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
  29. sglang/srt/layers/attention/flashattention_backend.py +0 -1
  30. sglang/srt/layers/attention/flashinfer_backend.py +223 -6
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
  32. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
  33. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  34. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
  35. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  36. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  37. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  38. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  39. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  40. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  41. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  42. sglang/srt/layers/attention/triton_backend.py +1 -1
  43. sglang/srt/layers/logits_processor.py +136 -6
  44. sglang/srt/layers/modelopt_utils.py +11 -0
  45. sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
  46. sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
  47. sglang/srt/layers/moe/ep_moe/layer.py +8 -286
  48. sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
  49. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  50. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  51. sglang/srt/layers/moe/utils.py +7 -1
  52. sglang/srt/layers/quantization/__init__.py +1 -1
  53. sglang/srt/layers/quantization/fp8.py +84 -18
  54. sglang/srt/layers/quantization/modelopt_quant.py +1 -1
  55. sglang/srt/layers/quantization/quark/quark.py +3 -1
  56. sglang/srt/layers/quantization/w4afp8.py +2 -16
  57. sglang/srt/lora/lora_manager.py +0 -8
  58. sglang/srt/managers/overlap_utils.py +18 -16
  59. sglang/srt/managers/schedule_batch.py +119 -90
  60. sglang/srt/managers/schedule_policy.py +1 -1
  61. sglang/srt/managers/scheduler.py +213 -126
  62. sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
  63. sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
  64. sglang/srt/managers/tokenizer_manager.py +270 -53
  65. sglang/srt/managers/tp_worker.py +39 -28
  66. sglang/srt/mem_cache/allocator.py +7 -2
  67. sglang/srt/mem_cache/chunk_cache.py +1 -1
  68. sglang/srt/mem_cache/memory_pool.py +162 -68
  69. sglang/srt/mem_cache/radix_cache.py +8 -3
  70. sglang/srt/mem_cache/swa_radix_cache.py +70 -14
  71. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  72. sglang/srt/model_executor/forward_batch_info.py +4 -18
  73. sglang/srt/model_executor/model_runner.py +55 -51
  74. sglang/srt/model_loader/__init__.py +1 -1
  75. sglang/srt/model_loader/loader.py +187 -6
  76. sglang/srt/model_loader/weight_utils.py +3 -0
  77. sglang/srt/models/falcon_h1.py +11 -9
  78. sglang/srt/models/gemma3_mm.py +16 -0
  79. sglang/srt/models/grok.py +5 -13
  80. sglang/srt/models/mixtral.py +1 -3
  81. sglang/srt/models/mllama4.py +11 -1
  82. sglang/srt/models/nemotron_h.py +514 -0
  83. sglang/srt/models/utils.py +5 -1
  84. sglang/srt/sampling/sampling_batch_info.py +11 -9
  85. sglang/srt/server_args.py +100 -33
  86. sglang/srt/speculative/eagle_worker.py +11 -13
  87. sglang/srt/speculative/ngram_worker.py +12 -11
  88. sglang/srt/speculative/spec_utils.py +0 -1
  89. sglang/srt/two_batch_overlap.py +1 -0
  90. sglang/srt/utils/common.py +18 -0
  91. sglang/srt/utils/hf_transformers_utils.py +2 -0
  92. sglang/test/longbench_v2/__init__.py +1 -0
  93. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  94. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  95. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  96. sglang/test/run_eval.py +40 -0
  97. sglang/test/simple_eval_longbench_v2.py +332 -0
  98. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  99. sglang/test/test_deterministic.py +18 -2
  100. sglang/test/test_deterministic_utils.py +81 -0
  101. sglang/test/test_disaggregation_utils.py +63 -0
  102. sglang/test/test_utils.py +32 -11
  103. sglang/version.py +1 -1
  104. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
  105. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
  106. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  107. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  108. sglang/test/test_block_fp8_ep.py +0 -358
  109. /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
  110. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  111. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  112. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -36,7 +36,6 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -94,8 +93,7 @@ class MixtralMoE(nn.Module):
             renormalize=True,
         )
 
-        MoEImpl = EPMoE if get_moe_expert_parallel_world_size() > 1 else FusedMoE
-        self.experts = MoEImpl(
+        self.experts = FusedMoE(
             num_experts=num_experts,
             top_k=top_k,
             layer_id=layer_id,
@@ -2,6 +2,7 @@ import json as json_lib
 import logging
 import math
 import os
+import re
 from collections.abc import Iterable
 from typing import List, Optional, Set, Tuple
 
@@ -422,6 +423,11 @@ class Llama4ForConditionalGeneration(nn.Module):
         "gate_up_proj": ["gate_proj", "up_proj"],
     }
 
+    # Pattern to match language model layers only (skip vision_model and multi_modal_projector)
+    lora_pattern = re.compile(
+        r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
+    )
+
     def __init__(
         self,
         config: Llama4Config,
@@ -555,6 +561,10 @@ class Llama4ForConditionalGeneration(nn.Module):
 
         return projected_vision_flat
 
+    def should_apply_lora(self, module_name: str) -> bool:
+        """Skip vision model and multi_modal_projector for LoRA."""
+        return bool(self.lora_pattern.match(module_name))
+
     def forward(
         self,
         input_ids: torch.Tensor,
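
The two hunks above add a class-level regex and a `should_apply_lora` hook so that LoRA is applied only to the language-model decoder projections of Llama4, never to the vision tower or the multi-modal projector. Below is a minimal standalone sketch of that matching behavior; the module names used are illustrative, not taken from the diff.

import re

# Same pattern as the new Llama4ForConditionalGeneration.lora_pattern above.
lora_pattern = re.compile(
    r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
)

def should_apply_lora(module_name: str) -> bool:
    # Mirrors the new method: only language-model decoder projections get LoRA.
    return bool(lora_pattern.match(module_name))

# Illustrative module names:
assert should_apply_lora("language_model.model.layers.0.self_attn.qkv_proj")
assert should_apply_lora("language_model.model.layers.17.mlp.gate_up_proj")
assert not should_apply_lora("vision_model.model.layers.0.self_attn.qkv_proj")
assert not should_apply_lora("multi_modal_projector.linear_1")
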
@@ -700,7 +710,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         """Handle scale parameter remapping. Returns True if handled."""
         if "scale" in name and "expert" not in name:
             remapped_name = maybe_remap_kv_scale_name(name, params_dict)
-            return remapped_name is not None and remapped_name != name
+            return remapped_name != name
         return False
 
     def _handle_stacked_params(
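
The return-value change above is subtle: `maybe_remap_kv_scale_name` can return None when a scale weight has no valid remap target. Previously that case returned False, so the weight kept flowing through the normal loading path; with `remapped_name != name`, a None result now counts as handled and the weight is skipped. A toy comparison of the two expressions, using a hypothetical stand-in for the remap helper (its exact behavior is not shown in this diff):

from typing import Optional

def _toy_remap(name: str, params_dict: dict) -> Optional[str]:
    # Hypothetical stand-in for maybe_remap_kv_scale_name: returns a remapped
    # name, or None when there is no valid target for this scale.
    mapping = {"layers.0.self_attn.k_proj.k_scale": "layers.0.self_attn.attn.k_scale"}
    return mapping.get(name)

params_dict = {"layers.0.self_attn.attn.k_scale": object()}

name = "layers.0.self_attn.unknown_scale"  # no remapping target exists
remapped = _toy_remap(name, params_dict)   # -> None

old_result = remapped is not None and remapped != name  # False: weight keeps flowing through the loader
new_result = remapped != name                           # True: weight is treated as handled and skipped
print(old_result, new_result)
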
@@ -0,0 +1,514 @@
+# Copyright 2023-2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/nemotron_h.py
+
+"""Inference-only NemotronH model."""
+
+from collections.abc import Iterable
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from sglang.srt.configs import NemotronHConfig
+from sglang.srt.configs.nemotron_h import ATTENTION, MAMBA, MLP
+from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from sglang.srt.layers.activation import ReLU2
+from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
+    HybridLinearAttnBackend,
+    Mamba2AttnBackend,
+)
+from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE,
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
+from sglang.srt.utils import add_prefix, make_layers_non_pp
+from sglang.utils import logger
+
+
+class NemotronHMLP(nn.Module):
+    def __init__(
+        self,
+        config: NemotronHConfig,
+        layer_idx: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        hybrid_override_pattern = config.hybrid_override_pattern
+        mlp_index = hybrid_override_pattern[: layer_idx + 1].count("-") - 1
+        if isinstance(config.intermediate_size, list):
+            if len(config.intermediate_size) == 1:
+                intermediate_size = config.intermediate_size[0]
+            else:
+                intermediate_size = config.intermediate_size[mlp_index]
+        else:
+            intermediate_size = config.intermediate_size
+
+        self.up_proj = ColumnParallelLinear(
+            input_size=config.hidden_size,
+            output_size=intermediate_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.up_proj",
+        )
+        self.down_proj = RowParallelLinear(
+            input_size=intermediate_size,
+            output_size=config.hidden_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
+        )
+        self.act_fn = ReLU2()
+
+    def forward(self, x: torch.Tensor):
+        x, _ = self.up_proj(x)
+        x = self.act_fn(x)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class NemotronHMLPDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: NemotronHConfig,
+        layer_idx: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+
+        self.mixer = NemotronHMLP(
+            config,
+            quant_config=quant_config,
+            bias=config.mlp_bias,
+            prefix=f"{prefix}.mixer",
+            layer_idx=layer_idx,
+        )
+
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        *,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        forward_batch: ForwardBatch,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.norm(hidden_states)
+        else:
+            hidden_states, residual = self.norm(hidden_states, residual)
+
+        hidden_states = self.mixer.forward(hidden_states)
+        return hidden_states, residual
+
+
+class NemotronHMambaDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: NemotronHConfig,
+        layer_idx: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.layer_id = layer_idx
+        self.mixer = MambaMixer2(
+            cache_params=config.mamba2_cache_params,
+            hidden_size=config.hidden_size,
+            use_conv_bias=config.use_conv_bias,
+            use_bias=config.use_bias,
+            n_groups=config.mamba_n_groups,
+            rms_norm_eps=config.rms_norm_eps,
+            activation=config.mamba_hidden_act,
+            quant_config=quant_config,
+        )
+
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        *,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        forward_batch: ForwardBatch,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.norm(hidden_states)
+        else:
+            hidden_states, residual = self.norm(hidden_states, residual)
+
+        output = torch.empty_like(hidden_states)
+        attn_backend = forward_batch.attn_backend
+        assert isinstance(attn_backend, HybridLinearAttnBackend)
+        assert isinstance(attn_backend.linear_attn_backend, Mamba2AttnBackend)
+        attn_backend.linear_attn_backend.forward(
+            mixer=self.mixer,
+            layer_id=self.layer_id,
+            hidden_states=hidden_states,
+            output=output,
+            use_triton_causal_conv=True,  # TODO: investigate need of `use_triton_causal_conv`
+        )
+        return output, residual
+
+
+class NemotronHAttention(nn.Module):
+    def __init__(
+        self,
+        config: NemotronHConfig,
+        layer_idx: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = config.num_attention_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = config.num_key_value_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        if hasattr(config, "head_dim") and config.head_dim is not None:
+            self.head_dim = config.head_dim
+        else:
+            self.head_dim = config.hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+
+        self.qkv_proj = QKVParallelLinear(
+            config.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
+        )
+
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_idx,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
+
+    def forward(
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        attn_output = self.attn.forward(q, k, v, forward_batch)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class NemotronHAttentionDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: NemotronHConfig,
+        layer_idx: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.mixer = NemotronHAttention(
+            config,
+            layer_idx,
+            quant_config,
+            prefix=f"{prefix}.mixer",
+        )
+
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        *,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        forward_batch: ForwardBatch,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.norm(hidden_states)
+        else:
+            hidden_states, residual = self.norm(hidden_states, residual)
+
+        hidden_states = self.mixer.forward(
+            hidden_states=hidden_states, forward_batch=forward_batch
+        )
+        return hidden_states, residual
+
+
+Layers = (
+    NemotronHAttentionDecoderLayer
+    | NemotronHMLPDecoderLayer
+    | NemotronHMambaDecoderLayer
+)
+ALL_DECODER_LAYER_TYPES: dict[str, type[Layers]] = {
+    ATTENTION: NemotronHAttentionDecoderLayer,
+    MLP: NemotronHMLPDecoderLayer,
+    MAMBA: NemotronHMambaDecoderLayer,
+}
+
+
+class NemotronHModel(nn.Module):
+    def __init__(
+        self,
+        *,
+        config: NemotronHConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        lora_config = None
+        self.config = config
+        lora_vocab = (
+            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
+            if lora_config
+            else 0
+        )
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.org_vocab_size = config.vocab_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            self.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+        )
+
+        def get_layer(idx: int, prefix: str):
+            layer_class = ALL_DECODER_LAYER_TYPES[config.hybrid_override_pattern[idx]]
+            return layer_class(config, idx, quant_config=quant_config, prefix=prefix)
+
+        self.layers = make_layers_non_pp(
+            len(config.hybrid_override_pattern), get_layer, prefix=f"{prefix}.layers"
+        )
+        self.norm_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        residual = None
+        for layer in self.layers:
+            if not isinstance(layer, Layers):
+                raise ValueError(f"Unknown layer type: {type(layer)}")
+            hidden_states, residual = layer.forward(
+                hidden_states=hidden_states,
+                residual=residual,
+                forward_batch=forward_batch,
+            )
+
+        if not get_pp_group().is_last_rank:
+            return PPProxyTensors(
+                {"hidden_states": hidden_states, "residual": residual}
+            )
+        hidden_states, _ = self.norm_f(hidden_states, residual)
+        return hidden_states
+
+
+class NemotronHForCausalLM(nn.Module):
+    remap_prefix = {"backbone": "model"}
+    remap_substr = {"A_log": "A", "embeddings": "embed_tokens"}
+
+    # LoRA specific attributes
+    embedding_modules = {
+        "embed_tokens": "input_embeddings",
+        "lm_head": "output_embeddings",
+    }
+    embedding_padding_modules = ["lm_head"]
+
+    def __init__(
+        self,
+        *,
+        config: NemotronHConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        lora_config = None
+        self.config = config
+        self.model = self._init_model(
+            config=config, quant_config=quant_config, prefix=prefix
+        )
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.unpadded_vocab_size = config.vocab_size
+            if lora_config:
+                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=(
+                    DEFAULT_VOCAB_PADDING_SIZE
+                    # We need bigger padding if using lora for kernel
+                    # compatibility
+                    if not lora_config
+                    else lora_config.lora_vocab_padding_size
+                ),
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
+            )
+        self.logits_processor = LogitsProcessor(config)
+
+    def _init_model(
+        self,
+        config: NemotronHConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        return NemotronHModel(config=config, quant_config=quant_config, prefix=prefix)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: Optional[torch.Tensor] = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ):
+        hidden_states = self.model.forward(
+            input_ids, positions, forward_batch, pp_proxy_tensors, input_embeds
+        )
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
+        return self.mamba_cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs)
+
+    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
+        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
+        updated_weights = []
+        for name, loaded_weight in weights:
+            for prefix, new_key in self.remap_prefix.items():
+                if name.startswith(prefix):
+                    name = name.replace(prefix, new_key)
+            for substr, new_key in self.remap_substr.items():
+                if substr in name:
+                    name = name.replace(substr, new_key)
+            updated_weights.append((name, loaded_weight))
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in updated_weights:
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
+
+
+EntryClass = [NemotronHForCausalLM]
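
The weight loader in the new NemotronH model renames checkpoint tensors before matching them against module parameters: the `backbone` prefix becomes `model`, `A_log` becomes `A`, and `embeddings` becomes `embed_tokens`. A small self-contained sketch of that remapping; the example weight names are illustrative rather than copied from a real NemotronH checkpoint.

# Same mappings as NemotronHForCausalLM.remap_prefix / remap_substr above.
remap_prefix = {"backbone": "model"}
remap_substr = {"A_log": "A", "embeddings": "embed_tokens"}

def remap(name: str) -> str:
    for prefix, new_key in remap_prefix.items():
        if name.startswith(prefix):
            name = name.replace(prefix, new_key)
    for substr, new_key in remap_substr.items():
        if substr in name:
            name = name.replace(substr, new_key)
    return name

print(remap("backbone.embeddings.weight"))               # model.embed_tokens.weight
print(remap("backbone.layers.3.mixer.A_log"))            # model.layers.3.mixer.A
print(remap("backbone.layers.7.mixer.qkv_proj.weight"))  # model.layers.7.mixer.qkv_proj.weight
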
@@ -27,7 +27,11 @@ if _is_cuda:
 
 def enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
     """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
-    return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
+    return (
+        _is_cuda
+        and hasattr(forward_batch.token_to_kv_pool, "dtype")
+        and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
+    )
 
 
 def create_fused_set_kv_buffer_arg(
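
The extra `hasattr` check above makes `enable_fused_set_kv_buffer` tolerant of KV pools that do not expose a `dtype` attribute (presumably the hybrid/Mamba-style pools touched elsewhere in this release; that motivation is an inference, not stated in the diff). A minimal sketch of the guard using stand-in objects:

import types

import torch

_is_cuda = True  # assume a CUDA build for this sketch

def enable_fused_set_kv_buffer(forward_batch):
    # Same guard as the new code above: tolerate KV pools without a `dtype` attribute.
    return (
        _is_cuda
        and hasattr(forward_batch.token_to_kv_pool, "dtype")
        and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
    )

# Illustrative stand-ins for forward batches holding different pool types.
bf16_pool_batch = types.SimpleNamespace(
    token_to_kv_pool=types.SimpleNamespace(dtype=torch.bfloat16)
)
no_dtype_pool_batch = types.SimpleNamespace(token_to_kv_pool=object())

print(enable_fused_set_kv_buffer(bf16_pool_batch))      # True
print(enable_fused_set_kv_buffer(no_dtype_pool_batch))  # False (previously raised AttributeError)
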
@@ -44,12 +44,9 @@ class SamplingBatchInfo:
     vocab_mask: Optional[torch.Tensor] = None
     apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
 
-    # An event used for overlap schedule
-    sampling_info_done: Optional[threading.Event] = None
-
     # Penalizer
     penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
-    linear_penalty: torch.Tensor = None
+    acc_linear_penalties: torch.Tensor = None  # Used in the overlap mode
 
     # Whether any request has custom logit processor
     has_custom_logit_processor: bool = False
@@ -217,19 +214,19 @@ class SamplingBatchInfo:
 
     def update_penalties(self):
         if self.penalizer_orchestrator.is_required:
-            self.linear_penalty = torch.zeros(
+            self.acc_linear_penalties = torch.zeros(
                 (len(self.temperatures), self.vocab_size),
                 dtype=torch.float32,
                 device=self.temperatures.device,
             )
-            self.penalizer_orchestrator.apply(self.linear_penalty)
+            self.penalizer_orchestrator.apply(self.acc_linear_penalties)
         else:
-            self.linear_penalty = None
+            self.acc_linear_penalties = None
 
     def apply_logits_bias(self, logits: torch.Tensor):
-        if self.linear_penalty is not None:
+        if self.acc_linear_penalties is not None:
             # Used in the overlap mode
-            logits.add_(self.linear_penalty)
+            logits.add_(self.acc_linear_penalties)
 
         if self.penalizer_orchestrator and self.penalizer_orchestrator.is_required:
             # Used in the non-overlap mode
@@ -370,6 +367,11 @@ class SamplingBatchInfo:
         self.need_top_k_sampling |= other.need_top_k_sampling
         self.need_min_p_sampling |= other.need_min_p_sampling
 
+    def copy_for_forward(self):
+        # Accumulate the penalty into a pre-allocated buffer to get rid of the dependency of `penalizer_orchestrator` later
+        self.update_penalties()
+        return dataclasses.replace(self, penalizer_orchestrator=None)
+
 
 def merge_bias_tensor(
     lhs: Optional[torch.Tensor],
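
Taken together, the sampling changes replace the `sampling_info_done` event and `linear_penalty` field with an `acc_linear_penalties` buffer plus a `copy_for_forward` helper: penalties are accumulated into a plain tensor up front, and the batch info handed to the forward path is a shallow copy with the penalizer orchestrator stripped out. A simplified, self-contained sketch of that flow; the Toy* classes are illustrative only and not sglang APIs.

import dataclasses
from typing import Optional

import torch

class ToyOrchestrator:
    is_required = True

    def apply(self, buf: torch.Tensor) -> None:
        buf[:, 0] -= 1.0  # e.g. penalize token id 0 for every request

@dataclasses.dataclass
class ToySamplingInfo:
    temperatures: torch.Tensor
    vocab_size: int
    penalizer_orchestrator: Optional[ToyOrchestrator] = None
    acc_linear_penalties: Optional[torch.Tensor] = None  # used in the overlap mode

    def update_penalties(self):
        if self.penalizer_orchestrator and self.penalizer_orchestrator.is_required:
            self.acc_linear_penalties = torch.zeros(
                (len(self.temperatures), self.vocab_size), dtype=torch.float32
            )
            self.penalizer_orchestrator.apply(self.acc_linear_penalties)
        else:
            self.acc_linear_penalties = None

    def copy_for_forward(self):
        # Same shape as the new SamplingBatchInfo.copy_for_forward above.
        self.update_penalties()
        return dataclasses.replace(self, penalizer_orchestrator=None)

info = ToySamplingInfo(torch.ones(2), vocab_size=4, penalizer_orchestrator=ToyOrchestrator())
forward_copy = info.copy_for_forward()
assert forward_copy.penalizer_orchestrator is None

logits = torch.zeros(2, 4)
logits.add_(forward_copy.acc_linear_penalties)  # what apply_logits_bias does in overlap mode
print(logits)
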