sglang 0.2.14.post2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. sglang/api.py +2 -0
  2. sglang/bench_latency.py +39 -28
  3. sglang/lang/backend/runtime_endpoint.py +8 -4
  4. sglang/lang/interpreter.py +3 -0
  5. sglang/lang/ir.py +5 -0
  6. sglang/launch_server_llavavid.py +12 -12
  7. sglang/srt/configs/__init__.py +5 -0
  8. sglang/srt/configs/exaone.py +195 -0
  9. sglang/srt/constrained/fsm_cache.py +1 -1
  10. sglang/srt/conversation.py +24 -2
  11. sglang/srt/hf_transformers_utils.py +12 -12
  12. sglang/srt/layers/extend_attention.py +13 -8
  13. sglang/srt/layers/logits_processor.py +4 -4
  14. sglang/srt/layers/sampler.py +94 -17
  15. sglang/srt/managers/controller_multi.py +5 -5
  16. sglang/srt/managers/controller_single.py +5 -5
  17. sglang/srt/managers/io_struct.py +6 -1
  18. sglang/srt/managers/schedule_batch.py +26 -11
  19. sglang/srt/managers/tokenizer_manager.py +9 -9
  20. sglang/srt/managers/tp_worker.py +38 -26
  21. sglang/srt/model_config.py +3 -3
  22. sglang/srt/model_executor/cuda_graph_runner.py +26 -9
  23. sglang/srt/model_executor/forward_batch_info.py +68 -23
  24. sglang/srt/model_executor/model_runner.py +15 -22
  25. sglang/srt/models/chatglm.py +9 -15
  26. sglang/srt/models/commandr.py +5 -1
  27. sglang/srt/models/dbrx.py +5 -1
  28. sglang/srt/models/deepseek.py +5 -1
  29. sglang/srt/models/deepseek_v2.py +57 -25
  30. sglang/srt/models/exaone.py +368 -0
  31. sglang/srt/models/gemma.py +5 -1
  32. sglang/srt/models/gemma2.py +5 -1
  33. sglang/srt/models/gpt_bigcode.py +5 -1
  34. sglang/srt/models/grok.py +5 -1
  35. sglang/srt/models/internlm2.py +5 -1
  36. sglang/srt/models/{llama2.py → llama.py} +25 -45
  37. sglang/srt/models/llama_classification.py +34 -41
  38. sglang/srt/models/llama_embedding.py +7 -6
  39. sglang/srt/models/llava.py +8 -11
  40. sglang/srt/models/llavavid.py +5 -6
  41. sglang/srt/models/minicpm.py +5 -1
  42. sglang/srt/models/mistral.py +2 -3
  43. sglang/srt/models/mixtral.py +6 -2
  44. sglang/srt/models/mixtral_quant.py +5 -1
  45. sglang/srt/models/qwen.py +5 -2
  46. sglang/srt/models/qwen2.py +6 -2
  47. sglang/srt/models/qwen2_moe.py +5 -14
  48. sglang/srt/models/stablelm.py +5 -1
  49. sglang/srt/openai_api/adapter.py +16 -1
  50. sglang/srt/openai_api/protocol.py +5 -5
  51. sglang/srt/sampling/sampling_batch_info.py +75 -6
  52. sglang/srt/server.py +6 -6
  53. sglang/srt/utils.py +0 -3
  54. sglang/test/runners.py +1 -1
  55. sglang/test/test_programs.py +68 -0
  56. sglang/test/test_utils.py +4 -0
  57. sglang/utils.py +39 -0
  58. sglang/version.py +1 -1
  59. {sglang-0.2.14.post2.dist-info → sglang-0.3.0.dist-info}/METADATA +9 -8
  60. sglang-0.3.0.dist-info/RECORD +118 -0
  61. {sglang-0.2.14.post2.dist-info → sglang-0.3.0.dist-info}/WHEEL +1 -1
  62. sglang-0.2.14.post2.dist-info/RECORD +0 -115
  63. {sglang-0.2.14.post2.dist-info → sglang-0.3.0.dist-info}/LICENSE +0 -0
  64. {sglang-0.2.14.post2.dist-info → sglang-0.3.0.dist-info}/top_level.txt +0 -0
sglang/srt/models/exaone.py ADDED
@@ -0,0 +1,368 @@
+ """
+ Copyright 2024 The LGcns AI Engineering Team
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ # Adapted from llama2.py
+ """Inference-only Exaone model compatible with HuggingFace weights."""
+
+ from typing import Any, Dict, Iterable, Optional, Tuple
+
+ import torch
+ from torch import nn
+ from vllm.config import CacheConfig
+ from vllm.distributed import get_tensor_model_parallel_world_size
+ from vllm.model_executor.layers.linear import (
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
+ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+ from vllm.model_executor.layers.rotary_embedding import get_rope
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+ from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
+ from sglang.srt.model_executor.forward_batch_info import InputMetadata
+
+
+ class ExaoneGatedMLP(nn.Module):
+     def __init__(
+         self,
+         hidden_size: int,
+         intermediate_size: int,
+         hidden_act: str,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.gate_up_proj = MergedColumnParallelLinear(
+             hidden_size,
+             [intermediate_size] * 2,
+             bias=False,
+             quant_config=quant_config,
+             prefix=f"{prefix}.gate_up_proj",
+         )
+         self.c_proj = RowParallelLinear(
+             intermediate_size,
+             hidden_size,
+             bias=False,
+             quant_config=quant_config,
+             prefix=f"{prefix}.c_proj",
+         )
+         if hidden_act != "silu":
+             raise ValueError(
+                 f"Unsupported activation: {hidden_act}. "
+                 "Only silu is supported for now."
+             )
+         self.act_fn = SiluAndMul()
+
+     def forward(self, x):
+         gate_up, _ = self.gate_up_proj(x)
+         x = self.act_fn(gate_up)
+         x, _ = self.c_proj(x)
+         return x
+
+
+ class ExaoneAttention(nn.Module):
+     def __init__(
+         self,
+         config,
+         hidden_size: int,
+         num_heads: int,
+         num_kv_heads: int,
+         layer_id: int = 0,
+         rope_theta: float = 500000,
+         rope_scaling: Optional[Dict[str, Any]] = None,
+         rope_is_neox_style: bool = True,
+         max_position_embeddings: int = 4096,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.hidden_size = hidden_size
+         tp_size = get_tensor_model_parallel_world_size()
+         self.total_num_heads = num_heads
+         assert self.total_num_heads % tp_size == 0
+         self.num_heads = self.total_num_heads // tp_size
+         self.total_num_kv_heads = num_kv_heads
+         if self.total_num_kv_heads >= tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+         # MistralConfig has an optional head_dim introduced by Mistral-Nemo
+         self.head_dim = getattr(
+             config, "head_dim", self.hidden_size // self.total_num_heads
+         )
+         self.rotary_dim = int(
+             self.head_dim * getattr(config, "partial_rotary_factor", 1)
+         )
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+         self.rope_theta = rope_theta
+         self.max_position_embeddings = max_position_embeddings
+
+         self.qkv_proj = QKVParallelLinear(
+             hidden_size,
+             self.head_dim,
+             self.total_num_heads,
+             self.total_num_kv_heads,
+             bias=False,
+             quant_config=quant_config,
+             prefix=f"{prefix}.qkv_proj",
+         )
+         self.out_proj = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             hidden_size,
+             bias=False,
+             quant_config=quant_config,
+             prefix=f"{prefix}.out_proj",
+         )
+
+         self.rotary_emb = get_rope(
+             self.head_dim,
+             rotary_dim=self.rotary_dim,
+             max_position=max_position_embeddings,
+             base=rope_theta,
+             rope_scaling=rope_scaling,
+             is_neox_style=rope_is_neox_style,
+         )
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             layer_id=layer_id,
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+         q, k = self.rotary_emb(positions, q, k)
+         attn_output = self.attn(q, k, v, input_metadata)
+         output, _ = self.out_proj(attn_output)
+         return output
+
+
+ class ExaoneDecoderLayer(nn.Module):
+     def __init__(
+         self,
+         config,
+         layer_id: int = 0,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         rope_theta = getattr(config, "rope_theta", 500000)
+         rope_scaling = getattr(config, "rope_scaling", None)
+         if rope_scaling is not None and getattr(
+             config, "original_max_position_embeddings", None
+         ):
+             rope_scaling["original_max_position_embeddings"] = (
+                 config.original_max_position_embeddings
+             )
+         rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
+         max_position_embeddings = getattr(config, "max_position_embeddings", 4096)
+         self.self_attn = ExaoneAttention(
+             config=config,
+             hidden_size=self.hidden_size,
+             num_heads=config.num_attention_heads,
+             num_kv_heads=config.num_key_value_heads,
+             layer_id=layer_id,
+             rope_theta=rope_theta,
+             rope_scaling=rope_scaling,
+             rope_is_neox_style=rope_is_neox_style,
+             max_position_embeddings=max_position_embeddings,
+             quant_config=quant_config,
+             prefix=f"{prefix}.self_attn",
+         )
+         self.mlp = ExaoneGatedMLP(
+             hidden_size=self.hidden_size,
+             intermediate_size=config.intermediate_size,
+             hidden_act=config.activation_function,
+             quant_config=quant_config,
+             prefix=f"{prefix}.mlp",
+         )
+         rms_norm_eps = config.layer_norm_epsilon
+         self.ln_1 = RMSNorm(config.hidden_size, eps=rms_norm_eps)
+         self.ln_2 = RMSNorm(config.hidden_size, eps=rms_norm_eps)
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+         residual: Optional[torch.Tensor],
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         # Self Attention
+         if residual is None:
+             residual = hidden_states
+             hidden_states = self.ln_1(hidden_states)
+         else:
+             hidden_states, residual = self.ln_1(hidden_states, residual)
+         hidden_states = self.self_attn(
+             positions=positions,
+             hidden_states=hidden_states,
+             input_metadata=input_metadata,
+         )
+
+         # Fully Connected
+         hidden_states, residual = self.ln_2(hidden_states, residual)
+         hidden_states = self.mlp(hidden_states)
+         return hidden_states, residual
+
+
+ class ExaoneModel(nn.Module):
+     def __init__(
+         self,
+         config,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+         self.wte = VocabParallelEmbedding(
+             config.vocab_size,
+             config.hidden_size,
+         )
+         self.h = nn.ModuleList(
+             [
+                 ExaoneDecoderLayer(
+                     config, i, quant_config=quant_config, prefix=f"model.h.{i}"
+                 )
+                 for i in range(config.num_hidden_layers)
+             ]
+         )
+         rms_norm_eps = config.layer_norm_epsilon
+         self.ln_f = RMSNorm(config.hidden_size, eps=rms_norm_eps)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         if input_embeds is None:
+             hidden_states = self.wte(input_ids)
+         else:
+             hidden_states = input_embeds
+         residual = None
+         for i in range(len(self.h)):
+             layer = self.h[i]
+             hidden_states, residual = layer(
+                 positions,
+                 hidden_states,
+                 input_metadata,
+                 residual,
+             )
+         hidden_states, _ = self.ln_f(hidden_states, residual)
+         return hidden_states
+
+
+ class ExaoneForCausalLM(nn.Module):
+     def __init__(
+         self,
+         config,
+         quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.quant_config = quant_config
+         self.transformer = ExaoneModel(config, quant_config=quant_config)
+         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+         self.logits_processor = LogitsProcessor(config)
+         self.sampler = Sampler()
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+         input_embeds: torch.Tensor = None,
+     ) -> LogitsProcessorOutput:
+         hidden_states = self.transformer(
+             input_ids, positions, input_metadata, input_embeds
+         )
+         logits_output = self.logits_processor(
+             input_ids, hidden_states, self.lm_head.weight, input_metadata
+         )
+         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+         return sample_output, logits_output
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "c_fc_0", 0),
+             ("gate_up_proj", "c_fc_1", 1),
+         ]
+         params_dict = dict(self.named_parameters())
+
+         for name, loaded_weight in weights:
+             if "rotary_emb.inv_freq" in name or "projector" in name:
+                 continue
+             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                 # Models trained using ColossalAI may include these tensors in
+                 # the checkpoint. Skip them.
+                 continue
+             if name.startswith("model.vision_tower") and name not in params_dict:
+                 continue
+
+             name = name.replace("attn.attention", "self_attn")
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                 weight_loader(param, loaded_weight)
+
+
+ EntryClass = ExaoneForCausalLM
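
The new file registers ExaoneForCausalLM via EntryClass, so an EXAONE checkpoint can be served and queried like any other sglang model. The snippet below is a minimal sketch that is not part of this diff: the model id LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct and port 30000 are assumptions, and it presumes a server was started separately (e.g. python -m sglang.launch_server --model-path <model id> --port 30000).

import sglang as sgl

# Assumes a local sglang server is already serving an EXAONE checkpoint on port 30000.
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

@sgl.function
def ask(s, question):
    s += question
    s += sgl.gen("answer", max_tokens=64)

state = ask.run(question="Briefly introduce yourself.")
print(state["answer"])
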
sglang/srt/models/gemma.py CHANGED
@@ -37,6 +37,7 @@ from sglang.srt.layers.activation import GeluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -287,6 +288,7 @@ class GemmaForCausalLM(nn.Module):
          self.quant_config = quant_config
          self.model = GemmaModel(config, quant_config=quant_config)
          self.logits_processor = LogitsProcessor(config)
+         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -297,9 +299,11 @@ class GemmaForCausalLM(nn.Module):
          input_embeds: torch.Tensor = None,
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-         return self.logits_processor(
+         logits_output = self.logits_processor(
              input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
          )
+         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+         return (sample_output, logits_output)

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
sglang/srt/models/gemma2.py CHANGED
@@ -37,6 +37,7 @@ from sglang.srt.layers.activation import GeluAndMul
  from sglang.srt.layers.layernorm import GemmaRMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -346,6 +347,7 @@ class Gemma2ForCausalLM(nn.Module):
          self.quant_config = quant_config
          self.model = Gemma2Model(config, cache_config, quant_config)
          self.logits_processor = LogitsProcessor(config)
+         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -356,9 +358,11 @@ class Gemma2ForCausalLM(nn.Module):
          input_embeds: torch.Tensor = None,
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-         return self.logits_processor(
+         logits_output = self.logits_processor(
              input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
          )
+         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+         return sample_output, logits_output

      def get_attention_sliding_window_size(self):
          return get_attention_sliding_window_size(self.config)
sglang/srt/models/gpt_bigcode.py CHANGED
@@ -35,6 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
  from sglang.srt.layers.activation import get_act_fn
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -261,6 +262,7 @@ class GPTBigCodeForCausalLM(nn.Module):
          if lora_config:
              self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
          self.logits_processor = LogitsProcessor(config)
+         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -270,9 +272,11 @@ class GPTBigCodeForCausalLM(nn.Module):
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
          hidden_states = self.transformer(input_ids, positions, input_metadata)
-         return self.logits_processor(
+         logits_output = self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
+         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          params_dict = dict(self.named_parameters(remove_duplicate=False))
sglang/srt/models/grok.py CHANGED
@@ -46,6 +46,7 @@ from sglang.srt.layers.fused_moe import FusedMoE
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -297,6 +298,7 @@ class Grok1ForCausalLM(nn.Module):
          self.model = Grok1Model(config, quant_config=quant_config)
          self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)
+         self.sampler = Sampler()

          # Monkey patch _prepare_weights to load pre-sharded weights
          setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
@@ -313,9 +315,11 @@ class Grok1ForCausalLM(nn.Module):
          input_embeds: torch.Tensor = None,
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-         return self.logits_processor(
+         logits_output = self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
+         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
sglang/srt/models/internlm2.py CHANGED
@@ -40,6 +40,7 @@ from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -262,6 +263,7 @@ class InternLM2ForCausalLM(nn.Module):
          self.model = InternLM2Model(config, quant_config)
          self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)
+         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -272,9 +274,11 @@ class InternLM2ForCausalLM(nn.Module):
          input_embeds: torch.Tensor = None,
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-         return self.logits_processor(
+         logits_output = self.logits_processor(
              input_ids, hidden_states, self.output.weight, input_metadata
          )
+         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
sglang/srt/models/{llama2.py → llama.py} RENAMED
@@ -39,8 +39,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
- from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor
+ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -294,7 +295,6 @@ class LlamaForCausalLM(nn.Module):
          config: LlamaConfig,
          quant_config: Optional[QuantizationConfig] = None,
          cache_config: Optional[CacheConfig] = None,
-         efficient_weight_load=False,
      ) -> None:
          super().__init__()
          self.config = config
@@ -302,6 +302,9 @@ class LlamaForCausalLM(nn.Module):
          self.model = LlamaModel(config, quant_config=quant_config)
          self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)
+         self.sampler = Sampler()
+
+         self.param_dict = dict(self.named_parameters())

      @torch.no_grad()
      def forward(
@@ -310,55 +313,34 @@ class LlamaForCausalLM(nn.Module):
          positions: torch.Tensor,
          input_metadata: InputMetadata,
          input_embeds: torch.Tensor = None,
-     ) -> LogitProcessorOutput:
+     ) -> LogitsProcessorOutput:
          hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-         return self.logits_processor(
+         logits_output = self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
+         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+         return sample_output, logits_output

-     def get_module_name(self, name):
-         stacked_params_mapping = [
-             # (param_name, shard_name, shard_id, num_shard)
-             ("qkv_proj", "q_proj", "q", 3),
-             ("qkv_proj", "k_proj", "k", 3),
-             ("qkv_proj", "v_proj", "v", 3),
-             ("gate_up_proj", "gate_proj", 0, 2),
-             ("gate_up_proj", "up_proj", 1, 2),
-         ]
-         for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
-             if weight_name in name:
-                 return (
-                     name.replace(weight_name, param_name)[: -len(".weight")],
-                     num_shard,
-                 )
-         return name[: -len(".weight")], 1
-
-     def get_num_params(self):
-         params_dict = dict(self.named_parameters())
-         return len(params_dict)
-
-     def load_weights(
-         self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
-     ):
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
              # (param_name, shard_name, shard_id)
-             ("qkv_proj", "q_proj", "q"),
-             ("qkv_proj", "k_proj", "k"),
-             ("qkv_proj", "v_proj", "v"),
-             ("gate_up_proj", "gate_proj", 0),
-             ("gate_up_proj", "up_proj", 1),
+             (".qkv_proj", ".q_proj", "q"),
+             (".qkv_proj", ".k_proj", "k"),
+             (".qkv_proj", ".v_proj", "v"),
+             (".gate_up_proj", ".gate_proj", 0),
+             (".gate_up_proj", ".up_proj", 1),
          ]
-         params_dict = dict(self.named_parameters())
+         params_dict = self.param_dict

-         def load_weights_per_param(name, loaded_weight):
+         for name, loaded_weight in weights:
              if "rotary_emb.inv_freq" in name or "projector" in name:
-                 return
+                 continue
              if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                  # Models trained using ColossalAI may include these tensors in
                  # the checkpoint. Skip them.
-                 return
+                 continue
              if name.startswith("model.vision_tower") and name not in params_dict:
-                 return
+                 continue

              for param_name, weight_name, shard_id in stacked_params_mapping:
                  if weight_name not in name:
@@ -374,16 +356,14 @@ class LlamaForCausalLM(nn.Module):
              else:
                  # Skip loading extra bias for GPTQ models.
                  if name.endswith(".bias") and name not in params_dict:
-                     return
+                     continue
                  param = params_dict[name]
                  weight_loader = getattr(param, "weight_loader", default_weight_loader)
                  weight_loader(param, loaded_weight)

-         if name is None or loaded_weight is None:
-             for name, loaded_weight in weights:
-                 load_weights_per_param(name, loaded_weight)
-         else:
-             load_weights_per_param(name, loaded_weight)
+
+ class Phi3ForCausalLM(LlamaForCausalLM):
+     pass


- EntryClass = LlamaForCausalLM
+ EntryClass = [LlamaForCausalLM, Phi3ForCausalLM]
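
The per-model hunks above all make the same two changes: each causal-LM wrapper now owns a Sampler next to its LogitsProcessor, and forward returns a (sample_output, logits_output) pair instead of the bare logits-processor output, drawing sampling parameters from input_metadata.sampling_info. A condensed sketch of that contract is shown below; NewModelForCausalLM is hypothetical rather than any file in this diff, and the backbone and lm-head constructors are placeholders.

import torch
from torch import nn

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata


class NewModelForCausalLM(nn.Module):
    # Hypothetical skeleton of the 0.3.0 model-side sampling contract.
    def __init__(self, config, quant_config=None, cache_config=None):
        super().__init__()
        self.model = build_backbone(config)  # placeholder: the transformer stack
        self.lm_head = build_lm_head(config)  # placeholder: vocab projection with .weight
        self.logits_processor = LogitsProcessor(config)
        self.sampler = Sampler()  # each model now constructs its own Sampler

    @torch.no_grad()
    def forward(self, input_ids, positions, input_metadata: InputMetadata):
        hidden_states = self.model(input_ids, positions, input_metadata)
        logits_output = self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )
        # Sampling parameters travel with the batch via input_metadata.sampling_info.
        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
        return sample_output, logits_output


EntryClass = NewModelForCausalLM
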