sglang 0.2.14.post2__py3-none-any.whl → 0.2.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. sglang/api.py +2 -0
  2. sglang/bench_latency.py +39 -28
  3. sglang/lang/interpreter.py +3 -0
  4. sglang/lang/ir.py +5 -0
  5. sglang/launch_server_llavavid.py +12 -12
  6. sglang/srt/configs/__init__.py +5 -0
  7. sglang/srt/configs/exaone.py +195 -0
  8. sglang/srt/constrained/fsm_cache.py +1 -1
  9. sglang/srt/conversation.py +24 -2
  10. sglang/srt/hf_transformers_utils.py +11 -11
  11. sglang/srt/layers/extend_attention.py +13 -8
  12. sglang/srt/layers/logits_processor.py +4 -4
  13. sglang/srt/layers/sampler.py +69 -16
  14. sglang/srt/managers/controller_multi.py +5 -5
  15. sglang/srt/managers/controller_single.py +5 -5
  16. sglang/srt/managers/io_struct.py +6 -1
  17. sglang/srt/managers/schedule_batch.py +20 -8
  18. sglang/srt/managers/tokenizer_manager.py +2 -2
  19. sglang/srt/managers/tp_worker.py +38 -26
  20. sglang/srt/model_config.py +3 -3
  21. sglang/srt/model_executor/cuda_graph_runner.py +24 -9
  22. sglang/srt/model_executor/forward_batch_info.py +68 -23
  23. sglang/srt/model_executor/model_runner.py +14 -12
  24. sglang/srt/models/chatglm.py +4 -12
  25. sglang/srt/models/commandr.py +5 -1
  26. sglang/srt/models/dbrx.py +5 -1
  27. sglang/srt/models/deepseek.py +5 -1
  28. sglang/srt/models/deepseek_v2.py +57 -25
  29. sglang/srt/models/exaone.py +399 -0
  30. sglang/srt/models/gemma.py +5 -1
  31. sglang/srt/models/gemma2.py +5 -1
  32. sglang/srt/models/gpt_bigcode.py +5 -1
  33. sglang/srt/models/grok.py +5 -1
  34. sglang/srt/models/internlm2.py +5 -1
  35. sglang/srt/models/llama2.py +7 -3
  36. sglang/srt/models/llama_classification.py +2 -2
  37. sglang/srt/models/minicpm.py +5 -1
  38. sglang/srt/models/mixtral.py +6 -2
  39. sglang/srt/models/mixtral_quant.py +5 -1
  40. sglang/srt/models/qwen.py +5 -2
  41. sglang/srt/models/qwen2.py +6 -2
  42. sglang/srt/models/qwen2_moe.py +5 -14
  43. sglang/srt/models/stablelm.py +5 -1
  44. sglang/srt/openai_api/adapter.py +16 -1
  45. sglang/srt/openai_api/protocol.py +5 -5
  46. sglang/srt/sampling/sampling_batch_info.py +79 -6
  47. sglang/srt/server.py +6 -6
  48. sglang/srt/utils.py +0 -3
  49. sglang/test/runners.py +1 -1
  50. sglang/version.py +1 -1
  51. {sglang-0.2.14.post2.dist-info → sglang-0.2.15.dist-info}/METADATA +7 -7
  52. {sglang-0.2.14.post2.dist-info → sglang-0.2.15.dist-info}/RECORD +55 -52
  53. {sglang-0.2.14.post2.dist-info → sglang-0.2.15.dist-info}/LICENSE +0 -0
  54. {sglang-0.2.14.post2.dist-info → sglang-0.2.15.dist-info}/WHEEL +0 -0
  55. {sglang-0.2.14.post2.dist-info → sglang-0.2.15.dist-info}/top_level.txt +0 -0
sglang/srt/models/exaone.py ADDED
@@ -0,0 +1,399 @@
+ """
+ Copyright 2024 The LGcns AI Engineering Team
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ # Adapted from llama2.py
+ """Inference-only Exaone model compatible with HuggingFace weights."""
+
+ from typing import Any, Dict, Iterable, Optional, Tuple
+
+ import torch
+ from torch import nn
+ from vllm.config import CacheConfig
+ from vllm.distributed import get_tensor_model_parallel_world_size
+ from vllm.model_executor.layers.linear import (
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
+ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+ from vllm.model_executor.layers.rotary_embedding import get_rope
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+ from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
+ from sglang.srt.model_executor.forward_batch_info import InputMetadata
+
+
+ class ExaoneGatedMLP(nn.Module):
+     def __init__(
+         self,
+         hidden_size: int,
+         intermediate_size: int,
+         hidden_act: str,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.gate_up_proj = MergedColumnParallelLinear(
+             hidden_size,
+             [intermediate_size] * 2,
+             bias=False,
+             quant_config=quant_config,
+             prefix=f"{prefix}.gate_up_proj",
+         )
+         self.c_proj = RowParallelLinear(
+             intermediate_size,
+             hidden_size,
+             bias=False,
+             quant_config=quant_config,
+             prefix=f"{prefix}.c_proj",
+         )
+         if hidden_act != "silu":
+             raise ValueError(
+                 f"Unsupported activation: {hidden_act}. "
+                 "Only silu is supported for now."
+             )
+         self.act_fn = SiluAndMul()
+
+     def forward(self, x):
+         gate_up, _ = self.gate_up_proj(x)
+         x = self.act_fn(gate_up)
+         x, _ = self.c_proj(x)
+         return x
+
+
+ class ExaoneAttention(nn.Module):
+     def __init__(
+         self,
+         config,
+         hidden_size: int,
+         num_heads: int,
+         num_kv_heads: int,
+         layer_id: int = 0,
+         rope_theta: float = 500000,
+         rope_scaling: Optional[Dict[str, Any]] = None,
+         rope_is_neox_style: bool = True,
+         max_position_embeddings: int = 4096,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.hidden_size = hidden_size
+         tp_size = get_tensor_model_parallel_world_size()
+         self.total_num_heads = num_heads
+         assert self.total_num_heads % tp_size == 0
+         self.num_heads = self.total_num_heads // tp_size
+         self.total_num_kv_heads = num_kv_heads
+         if self.total_num_kv_heads >= tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+         # MistralConfig has an optional head_dim introduced by Mistral-Nemo
+         self.head_dim = getattr(
+             config, "head_dim", self.hidden_size // self.total_num_heads
+         )
+         self.rotary_dim = int(
+             self.head_dim * getattr(config, "partial_rotary_factor", 1)
+         )
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+         self.rope_theta = rope_theta
+         self.max_position_embeddings = max_position_embeddings
+
+         self.qkv_proj = QKVParallelLinear(
+             hidden_size,
+             self.head_dim,
+             self.total_num_heads,
+             self.total_num_kv_heads,
+             bias=False,
+             quant_config=quant_config,
+             prefix=f"{prefix}.qkv_proj",
+         )
+         self.out_proj = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             hidden_size,
+             bias=False,
+             quant_config=quant_config,
+             prefix=f"{prefix}.out_proj",
+         )
+
+         self.rotary_emb = get_rope(
+             self.head_dim,
+             rotary_dim=self.rotary_dim,
+             max_position=max_position_embeddings,
+             base=rope_theta,
+             rope_scaling=rope_scaling,
+             is_neox_style=rope_is_neox_style,
+         )
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             layer_id=layer_id,
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+         q, k = self.rotary_emb(positions, q, k)
+         attn_output = self.attn(q, k, v, input_metadata)
+         output, _ = self.out_proj(attn_output)
+         return output
+
+
+ class ExaoneDecoderLayer(nn.Module):
+     def __init__(
+         self,
+         config,
+         layer_id: int = 0,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         rope_theta = getattr(config, "rope_theta", 500000)
+         rope_scaling = getattr(config, "rope_scaling", None)
+         if rope_scaling is not None and getattr(
+             config, "original_max_position_embeddings", None
+         ):
+             rope_scaling["original_max_position_embeddings"] = (
+                 config.original_max_position_embeddings
+             )
+         rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
+         max_position_embeddings = getattr(config, "max_position_embeddings", 4096)
+         self.self_attn = ExaoneAttention(
+             config=config,
+             hidden_size=self.hidden_size,
+             num_heads=config.num_attention_heads,
+             num_kv_heads=config.num_key_value_heads,
+             layer_id=layer_id,
+             rope_theta=rope_theta,
+             rope_scaling=rope_scaling,
+             rope_is_neox_style=rope_is_neox_style,
+             max_position_embeddings=max_position_embeddings,
+             quant_config=quant_config,
+             prefix=f"{prefix}.self_attn",
+         )
+         self.mlp = ExaoneGatedMLP(
+             hidden_size=self.hidden_size,
+             intermediate_size=config.intermediate_size,
+             hidden_act=config.activation_function,
+             quant_config=quant_config,
+             prefix=f"{prefix}.mlp",
+         )
+         rms_norm_eps = config.layer_norm_epsilon
+         self.ln_1 = RMSNorm(config.hidden_size, eps=rms_norm_eps)
+         self.ln_2 = RMSNorm(config.hidden_size, eps=rms_norm_eps)
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+         residual: Optional[torch.Tensor],
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         # Self Attention
+         if residual is None:
+             residual = hidden_states
+             hidden_states = self.ln_1(hidden_states)
+         else:
+             hidden_states, residual = self.ln_1(hidden_states, residual)
+         hidden_states = self.self_attn(
+             positions=positions,
+             hidden_states=hidden_states,
+             input_metadata=input_metadata,
+         )
+
+         # Fully Connected
+         hidden_states, residual = self.ln_2(hidden_states, residual)
+         hidden_states = self.mlp(hidden_states)
+         return hidden_states, residual
+
+
+ class ExaoneModel(nn.Module):
+     def __init__(
+         self,
+         config,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+         self.wte = VocabParallelEmbedding(
+             config.vocab_size,
+             config.hidden_size,
+         )
+         self.h = nn.ModuleList(
+             [
+                 ExaoneDecoderLayer(
+                     config, i, quant_config=quant_config, prefix=f"model.h.{i}"
+                 )
+                 for i in range(config.num_hidden_layers)
+             ]
+         )
+         rms_norm_eps = config.layer_norm_epsilon
+         self.ln_f = RMSNorm(config.hidden_size, eps=rms_norm_eps)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         if input_embeds is None:
+             hidden_states = self.wte(input_ids)
+         else:
+             hidden_states = input_embeds
+         residual = None
+         for i in range(len(self.h)):
+             layer = self.h[i]
+             hidden_states, residual = layer(
+                 positions,
+                 hidden_states,
+                 input_metadata,
+                 residual,
+             )
+         hidden_states, _ = self.ln_f(hidden_states, residual)
+         return hidden_states
+
+
+ class ExaoneForCausalLM(nn.Module):
+     def __init__(
+         self,
+         config,
+         quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
+         efficient_weight_load=False,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.quant_config = quant_config
+         self.transformer = ExaoneModel(config, quant_config=quant_config)
+         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+         self.logits_processor = LogitsProcessor(config)
+         self.sampler = Sampler()
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+         input_embeds: torch.Tensor = None,
+     ) -> LogitsProcessorOutput:
+         hidden_states = self.transformer(
+             input_ids, positions, input_metadata, input_embeds
+         )
+         logits_output = self.logits_processor(
+             input_ids, hidden_states, self.lm_head.weight, input_metadata
+         )
+         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+         return sample_output, logits_output
+
+     def get_module_name(self, name):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id, num_shard)
+             ("qkv_proj", "q_proj", "q", 3),
+             ("qkv_proj", "k_proj", "k", 3),
+             ("qkv_proj", "v_proj", "v", 3),
+             ("gate_up_proj", "c_fc_0", 0, 2),
+             ("gate_up_proj", "c_fc_1", 1, 2),
+         ]
+         for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
+             if weight_name in name:
+                 return (
+                     name.replace(weight_name, param_name)[: -len(".weight")],
+                     num_shard,
+                 )
+         return name[: -len(".weight")], 1
+
+     def get_num_params(self):
+         params_dict = dict(self.named_parameters())
+         return len(params_dict)
+
+     def load_weights(
+         self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
+     ):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "c_fc_0", 0),
+             ("gate_up_proj", "c_fc_1", 1),
+         ]
+         params_dict = dict(self.named_parameters())
+
+         def load_weights_per_param(name, loaded_weight):
+             if "rotary_emb.inv_freq" in name or "projector" in name:
+                 return
+             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                 # Models trained using ColossalAI may include these tensors in
+                 # the checkpoint. Skip them.
+                 return
+             if name.startswith("model.vision_tower") and name not in params_dict:
+                 return
+
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     return
+                 param = params_dict[name]
+                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                 weight_loader(param, loaded_weight)
+
+         if name is None or loaded_weight is None:
+             for name, loaded_weight in weights:
+                 name = name.replace("attn.attention", "self_attn")
+                 load_weights_per_param(name, loaded_weight)
+         else:
+             name = name.replace("attn.attention", "self_attn")
+             load_weights_per_param(name, loaded_weight)
+
+
+ EntryClass = ExaoneForCausalLM
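The new file registers ExaoneForCausalLM through EntryClass, so an EXAONE checkpoint can be served like the other supported architectures once 0.2.15 is installed. A minimal usage sketch, assuming a locally launched sglang server and the publicly released LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct checkpoint (the model path, port, and prompt below are illustrative assumptions, not part of this diff):

# Start the server first, for example:
#   python -m sglang.launch_server --model-path LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct --port 30000
# Then query it through the OpenAI-compatible API served by sglang.srt.openai_api.
import openai

client = openai.Client(base_url="http://localhost:30000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Introduce yourself in one sentence."}],
    max_tokens=64,
)
print(response.choices[0].message.content)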
sglang/srt/models/gemma.py CHANGED
@@ -37,6 +37,7 @@ from sglang.srt.layers.activation import GeluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -287,6 +288,7 @@ class GemmaForCausalLM(nn.Module):
          self.quant_config = quant_config
          self.model = GemmaModel(config, quant_config=quant_config)
          self.logits_processor = LogitsProcessor(config)
+         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -297,9 +299,11 @@ class GemmaForCausalLM(nn.Module):
          input_embeds: torch.Tensor = None,
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-         return self.logits_processor(
+         logits_output = self.logits_processor(
              input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
          )
+         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+         return (sample_output, logits_output)

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
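The Gemma hunks above show the pattern repeated across every decoder-only model in this release: the model constructs a Sampler, and forward now returns a (sample_output, logits_output) pair instead of the bare logits-processor output. A rough caller-side sketch of the new contract; the batch_next_token_ids attribute is an assumption for illustration, while next_token_logits and next_token_logprobs are fields visible elsewhere in this diff:

# Hypothetical caller-side view of the new return convention.
sample_output, logits_output = model.forward(input_ids, positions, input_metadata)

# Logits and logprobs are still exposed for logprob reporting.
next_token_logits = logits_output.next_token_logits
next_token_logprobs = logits_output.next_token_logprobs

# Sampling now happens inside the model; the sampled ids come back in
# sample_output (attribute name assumed, not shown in this diff).
next_token_ids = sample_output.batch_next_token_ids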
sglang/srt/models/gemma2.py CHANGED
@@ -37,6 +37,7 @@ from sglang.srt.layers.activation import GeluAndMul
  from sglang.srt.layers.layernorm import GemmaRMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -346,6 +347,7 @@ class Gemma2ForCausalLM(nn.Module):
          self.quant_config = quant_config
          self.model = Gemma2Model(config, cache_config, quant_config)
          self.logits_processor = LogitsProcessor(config)
+         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -356,9 +358,11 @@ class Gemma2ForCausalLM(nn.Module):
          input_embeds: torch.Tensor = None,
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-         return self.logits_processor(
+         logits_output = self.logits_processor(
              input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
          )
+         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+         return sample_output, logits_output

      def get_attention_sliding_window_size(self):
          return get_attention_sliding_window_size(self.config)
sglang/srt/models/gpt_bigcode.py CHANGED
@@ -35,6 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
  from sglang.srt.layers.activation import get_act_fn
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -261,6 +262,7 @@ class GPTBigCodeForCausalLM(nn.Module):
          if lora_config:
              self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
          self.logits_processor = LogitsProcessor(config)
+         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -270,9 +272,11 @@ class GPTBigCodeForCausalLM(nn.Module):
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
          hidden_states = self.transformer(input_ids, positions, input_metadata)
-         return self.logits_processor(
+         logits_output = self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
+         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          params_dict = dict(self.named_parameters(remove_duplicate=False))
sglang/srt/models/grok.py CHANGED
@@ -46,6 +46,7 @@ from sglang.srt.layers.fused_moe import FusedMoE
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -297,6 +298,7 @@ class Grok1ForCausalLM(nn.Module):
          self.model = Grok1Model(config, quant_config=quant_config)
          self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)
+         self.sampler = Sampler()

          # Monkey patch _prepare_weights to load pre-sharded weights
          setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
@@ -313,9 +315,11 @@ class Grok1ForCausalLM(nn.Module):
          input_embeds: torch.Tensor = None,
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-         return self.logits_processor(
+         logits_output = self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
+         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
sglang/srt/models/internlm2.py CHANGED
@@ -40,6 +40,7 @@ from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -262,6 +263,7 @@ class InternLM2ForCausalLM(nn.Module):
          self.model = InternLM2Model(config, quant_config)
          self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)
+         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -272,9 +274,11 @@ class InternLM2ForCausalLM(nn.Module):
          input_embeds: torch.Tensor = None,
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-         return self.logits_processor(
+         logits_output = self.logits_processor(
              input_ids, hidden_states, self.output.weight, input_metadata
          )
+         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
sglang/srt/models/llama2.py CHANGED
@@ -39,8 +39,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
- from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor
+ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -302,6 +303,7 @@ class LlamaForCausalLM(nn.Module):
          self.model = LlamaModel(config, quant_config=quant_config)
          self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)
+         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -310,11 +312,13 @@ class LlamaForCausalLM(nn.Module):
          positions: torch.Tensor,
          input_metadata: InputMetadata,
          input_embeds: torch.Tensor = None,
-     ) -> LogitProcessorOutput:
+     ) -> LogitsProcessorOutput:
          hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-         return self.logits_processor(
+         logits_output = self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
+         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+         return sample_output, logits_output

      def get_module_name(self, name):
          stacked_params_mapping = [
sglang/srt/models/llama_classification.py CHANGED
@@ -24,7 +24,7 @@ from vllm.distributed import get_tensor_model_parallel_rank
  from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

- from sglang.srt.layers.logits_processor import LogitProcessorOutput
+ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.model_executor.forward_batch_info import InputMetadata
  from sglang.srt.models.llama2 import LlamaModel

@@ -65,7 +65,7 @@ class LlamaForClassification(nn.Module):
                  (input_metadata.batch_size, self.config.classification_out_size)
              ).to(input_ids.device)

-         return LogitProcessorOutput(
+         return LogitsProcessorOutput(
              next_token_logits=scores,
              next_token_logprobs=scores,
              normalized_prompt_logprobs=scores,
sglang/srt/models/minicpm.py CHANGED
@@ -39,6 +39,7 @@ from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -297,6 +298,7 @@ class MiniCPMForCausalLM(nn.Module):
          self.scale_width = self.config.hidden_size / self.config.dim_model_base

          self.logits_processor = LogitsProcessor(config)
+         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -314,9 +316,11 @@ class MiniCPMForCausalLM(nn.Module):
              lm_head_weight = self.model.embed_tokens.weight
          else:
              lm_head_weight = self.lm_head.weight
-         return self.logits_processor(
+         logits_output = self.logits_processor(
              input_ids, hidden_states, lm_head_weight, input_metadata
          )
+         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
sglang/srt/models/mixtral.py CHANGED
@@ -41,6 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -299,6 +300,7 @@ class MixtralForCausalLM(nn.Module):
          self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
          self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)
+         self.sampler = Sampler()

      def forward(
          self,
@@ -308,9 +310,11 @@ class MixtralForCausalLM(nn.Module):
          input_embeds: torch.Tensor = None,
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-         return self.logits_processor(
+         logits_output = self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
+         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
@@ -358,7 +362,7 @@ class MixtralForCausalLM(nn.Module):
                      weight_loader(
                          param,
                          loaded_weight,
-                         weight_name,
+                         name,
                          shard_id=shard_id,
                          expert_id=expert_id,
                      )