sglang 0.2.14.post1__py3-none-any.whl → 0.2.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. sglang/api.py +2 -0
  2. sglang/bench_latency.py +39 -28
  3. sglang/lang/interpreter.py +3 -0
  4. sglang/lang/ir.py +5 -0
  5. sglang/launch_server_llavavid.py +26 -0
  6. sglang/srt/configs/__init__.py +5 -0
  7. sglang/srt/configs/exaone.py +195 -0
  8. sglang/srt/constrained/fsm_cache.py +1 -1
  9. sglang/srt/conversation.py +24 -2
  10. sglang/srt/hf_transformers_utils.py +11 -160
  11. sglang/srt/layers/activation.py +10 -4
  12. sglang/srt/layers/extend_attention.py +13 -8
  13. sglang/srt/layers/layernorm.py +47 -1
  14. sglang/srt/layers/logits_processor.py +4 -4
  15. sglang/srt/layers/sampler.py +69 -16
  16. sglang/srt/managers/controller_multi.py +5 -5
  17. sglang/srt/managers/controller_single.py +5 -5
  18. sglang/srt/managers/io_struct.py +11 -5
  19. sglang/srt/managers/schedule_batch.py +25 -13
  20. sglang/srt/managers/tokenizer_manager.py +76 -63
  21. sglang/srt/managers/tp_worker.py +47 -36
  22. sglang/srt/model_config.py +3 -3
  23. sglang/srt/model_executor/cuda_graph_runner.py +24 -9
  24. sglang/srt/model_executor/forward_batch_info.py +78 -43
  25. sglang/srt/model_executor/model_runner.py +29 -18
  26. sglang/srt/models/chatglm.py +5 -13
  27. sglang/srt/models/commandr.py +5 -1
  28. sglang/srt/models/dbrx.py +5 -1
  29. sglang/srt/models/deepseek.py +5 -1
  30. sglang/srt/models/deepseek_v2.py +57 -25
  31. sglang/srt/models/exaone.py +399 -0
  32. sglang/srt/models/gemma.py +7 -3
  33. sglang/srt/models/gemma2.py +6 -52
  34. sglang/srt/models/gpt_bigcode.py +5 -1
  35. sglang/srt/models/grok.py +14 -4
  36. sglang/srt/models/internlm2.py +5 -1
  37. sglang/srt/models/llama2.py +10 -7
  38. sglang/srt/models/llama_classification.py +2 -6
  39. sglang/srt/models/llama_embedding.py +3 -4
  40. sglang/srt/models/llava.py +69 -91
  41. sglang/srt/models/llavavid.py +40 -86
  42. sglang/srt/models/minicpm.py +5 -1
  43. sglang/srt/models/mixtral.py +6 -2
  44. sglang/srt/models/mixtral_quant.py +5 -1
  45. sglang/srt/models/qwen.py +5 -2
  46. sglang/srt/models/qwen2.py +9 -6
  47. sglang/srt/models/qwen2_moe.py +12 -33
  48. sglang/srt/models/stablelm.py +5 -1
  49. sglang/srt/models/yivl.py +2 -7
  50. sglang/srt/openai_api/adapter.py +16 -1
  51. sglang/srt/openai_api/protocol.py +5 -5
  52. sglang/srt/sampling/sampling_batch_info.py +79 -6
  53. sglang/srt/server.py +9 -9
  54. sglang/srt/utils.py +18 -36
  55. sglang/test/runners.py +2 -2
  56. sglang/test/test_layernorm.py +53 -1
  57. sglang/version.py +1 -1
  58. {sglang-0.2.14.post1.dist-info → sglang-0.2.15.dist-info}/METADATA +8 -8
  59. sglang-0.2.15.dist-info/RECORD +118 -0
  60. sglang-0.2.14.post1.dist-info/RECORD +0 -114
  61. {sglang-0.2.14.post1.dist-info → sglang-0.2.15.dist-info}/LICENSE +0 -0
  62. {sglang-0.2.14.post1.dist-info → sglang-0.2.15.dist-info}/WHEEL +0 -0
  63. {sglang-0.2.14.post1.dist-info → sglang-0.2.15.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
@@ -19,6 +19,7 @@ limitations under the License.
 from typing import Any, Dict, Iterable, Optional, Tuple

 import torch
+from flashinfer import bmm_fp8
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
@@ -45,6 +46,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -160,6 +162,15 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     return 0.1 * mscale * math.log(scale) + 1.0


+def input_to_float8(x, dtype=torch.float8_e4m3fn):
+    finfo = torch.finfo(dtype)
+    min_val, max_val = x.aminmax()
+    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
+    scale = finfo.max / amax
+    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
+    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
+
+
 class DeepseekV2Attention(nn.Module):

     def __init__(
@@ -254,11 +265,6 @@ class DeepseekV2Attention(nn.Module):
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale

-        # self.attn = Attention(self.num_heads,
-        #                       self.qk_head_dim,
-        #                       self.scaling,
-        #                       num_kv_heads=self.num_heads)
-
         # TODO, support head_size 192
         self.attn = RadixAttention(
             self.num_local_heads,
@@ -282,7 +288,7 @@ class DeepseekV2Attention(nn.Module):
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
             )
-        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         latent_cache = latent_cache.unsqueeze(1)
@@ -416,12 +422,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             v_head_dim=self.kv_lora_rank,
         )

-        kv_b_proj = self.kv_b_proj
-        w_kc, w_vc = kv_b_proj.weight.unflatten(
-            0, (-1, qk_nope_head_dim + v_head_dim)
-        ).split([qk_nope_head_dim, v_head_dim], dim=1)
-        self.w_kc = w_kc
-        self.w_vc = w_vc
+        self.w_kc = None
+        self.w_vc = None
+        self.w_scale = None

     def forward(
         self,
@@ -442,8 +445,17 @@ class DeepseekV2AttentionMLA(nn.Module):
                 -1, self.num_local_heads, self.qk_head_dim
             )
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        q_nope_out = q_input[..., : self.kv_lora_rank]
-        torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1))
+
+        if self.w_kc.dtype == torch.float8_e4m3fn:
+            q_nope_val, q_nope_scale = input_to_float8(
+                q_nope.transpose(0, 1), torch.float8_e4m3fn
+            )
+            q_nope_out = bmm_fp8(
+                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
+            )
+        else:
+            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
+        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)

         latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         v_input = latent_cache[..., : self.kv_lora_rank]
@@ -458,16 +470,21 @@ class DeepseekV2AttentionMLA(nn.Module):

         attn_output = self.attn(q_input, k_input, v_input, input_metadata)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
-        attn_bmm_output = attn_output.new_empty(
-            q_len, self.num_local_heads, self.v_head_dim
-        )
-        torch.bmm(
-            attn_output.transpose(0, 1),
-            self.w_vc.transpose(1, 2).contiguous(),
-            out=attn_bmm_output.transpose(0, 1),
-        )

-        attn_output = attn_bmm_output.flatten(1, 2)
+        if self.w_vc.dtype == torch.float8_e4m3fn:
+            attn_output_val, attn_output_scale = input_to_float8(
+                attn_output.transpose(0, 1), torch.float8_e4m3fn
+            )
+            attn_bmm_output = bmm_fp8(
+                attn_output_val,
+                self.w_vc,
+                attn_output_scale,
+                self.w_scale,
+                torch.bfloat16,
+            )
+        else:
+            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
+        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
         output, _ = self.o_proj(attn_output)

         return output
@@ -632,6 +649,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     def forward(
         self,
@@ -640,9 +658,11 @@
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -695,7 +715,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                     weight_loader(
                         param,
                         loaded_weight,
-                        weight_name,
+                        name,
                         shard_id=shard_id,
                         expert_id=expert_id,
                     )
@@ -711,5 +731,17 @@
                     )
                     weight_loader(param, loaded_weight)

+        if global_server_args_dict["enable_mla"]:
+            for layer_id in range(self.config.num_hidden_layers):
+                self_attn = self.model.layers[layer_id].self_attn
+                w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
+                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+                if hasattr(self_attn.kv_b_proj, "weight_scale"):
+                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                del self_attn.kv_b_proj
+

 EntryClass = DeepseekV2ForCausalLM
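
For reference, the input_to_float8 helper added to deepseek_v2.py above is plain per-tensor dynamic fp8 quantization. A minimal round-trip sketch (assuming PyTorch >= 2.1 for torch.float8_e4m3fn; flashinfer's bmm_fp8 is omitted, so the dequantized check below only illustrates the scaling, not the kernel path):

import torch

def input_to_float8(x, dtype=torch.float8_e4m3fn):
    # Scale so the largest magnitude maps to the fp8 dtype's max value, then
    # return the quantized tensor plus the inverse scale for dequantization.
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()

x = torch.randn(4, 8)
x_fp8, inv_scale = input_to_float8(x)
# Dequantize and check that the round-trip error stays small.
print((x - x_fp8.to(torch.float32) * inv_scale).abs().max())
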
sglang/srt/models/exaone.py (new file)
@@ -0,0 +1,399 @@
+"""
+Copyright 2024 The LGcns AI Engineering Team
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+# Adapted from llama2.py
+"""Inference-only Exaone model compatible with HuggingFace weights."""
+
+from typing import Any, Dict, Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from vllm.config import CacheConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
+
+
+class ExaoneGatedMLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+        )
+        self.c_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
+        )
+        if hidden_act != "silu":
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Only silu is supported for now."
+            )
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.c_proj(x)
+        return x
+
+
+class ExaoneAttention(nn.Module):
+    def __init__(
+        self,
+        config,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        layer_id: int = 0,
+        rope_theta: float = 500000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        rope_is_neox_style: bool = True,
+        max_position_embeddings: int = 4096,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
+        self.head_dim = getattr(
+            config, "head_dim", self.hidden_size // self.total_num_heads
+        )
+        self.rotary_dim = int(
+            self.head_dim * getattr(config, "partial_rotary_factor", 1)
+        )
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+        )
+        self.out_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.rotary_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=rope_is_neox_style,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, input_metadata)
+        output, _ = self.out_proj(attn_output)
+        return output
+
+
+class ExaoneDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 500000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        if rope_scaling is not None and getattr(
+            config, "original_max_position_embeddings", None
+        ):
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings
+            )
+        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 4096)
+        self.self_attn = ExaoneAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            layer_id=layer_id,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            rope_is_neox_style=rope_is_neox_style,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
+        )
+        self.mlp = ExaoneGatedMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.activation_function,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+        )
+        rms_norm_eps = config.layer_norm_epsilon
+        self.ln_1 = RMSNorm(config.hidden_size, eps=rms_norm_eps)
+        self.ln_2 = RMSNorm(config.hidden_size, eps=rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.ln_1(hidden_states)
+        else:
+            hidden_states, residual = self.ln_1(hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            input_metadata=input_metadata,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.ln_2(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class ExaoneModel(nn.Module):
+    def __init__(
+        self,
+        config,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.wte = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.h = nn.ModuleList(
+            [
+                ExaoneDecoderLayer(
+                    config, i, quant_config=quant_config, prefix=f"model.h.{i}"
+                )
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        rms_norm_eps = config.layer_norm_epsilon
+        self.ln_f = RMSNorm(config.hidden_size, eps=rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.wte(input_ids)
+        else:
+            hidden_states = input_embeds
+        residual = None
+        for i in range(len(self.h)):
+            layer = self.h[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                input_metadata,
+                residual,
+            )
+        hidden_states, _ = self.ln_f(hidden_states, residual)
+        return hidden_states
+
+
+class ExaoneForCausalLM(nn.Module):
+    def __init__(
+        self,
+        config,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+        efficient_weight_load=False,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.transformer = ExaoneModel(config, quant_config=quant_config)
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> LogitsProcessorOutput:
+        hidden_states = self.transformer(
+            input_ids, positions, input_metadata, input_embeds
+        )
+        logits_output = self.logits_processor(
+            input_ids, hidden_states, self.lm_head.weight, input_metadata
+        )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
+
+    def get_module_name(self, name):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id, num_shard)
+            ("qkv_proj", "q_proj", "q", 3),
+            ("qkv_proj", "k_proj", "k", 3),
+            ("qkv_proj", "v_proj", "v", 3),
+            ("gate_up_proj", "c_fc_0", 0, 2),
+            ("gate_up_proj", "c_fc_1", 1, 2),
+        ]
+        for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
+            if weight_name in name:
+                return (
+                    name.replace(weight_name, param_name)[: -len(".weight")],
+                    num_shard,
+                )
+        return name[: -len(".weight")], 1
+
+    def get_num_params(self):
+        params_dict = dict(self.named_parameters())
+        return len(params_dict)
+
+    def load_weights(
+        self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
+    ):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "c_fc_0", 0),
+            ("gate_up_proj", "c_fc_1", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+
+        def load_weights_per_param(name, loaded_weight):
+            if "rotary_emb.inv_freq" in name or "projector" in name:
+                return
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                return
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                return
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    return
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+        if name is None or loaded_weight is None:
+            for name, loaded_weight in weights:
+                name = name.replace("attn.attention", "self_attn")
+                load_weights_per_param(name, loaded_weight)
+        else:
+            name = name.replace("attn.attention", "self_attn")
+            load_weights_per_param(name, loaded_weight)
+
+
+EntryClass = ExaoneForCausalLM
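
The load_weights path above maps HuggingFace EXAONE checkpoint names onto the fused sglang/vLLM parameters (q/k/v into qkv_proj, the two MLP projections into gate_up_proj). A standalone sketch of just that name remapping; the example weight name is illustrative:

# Hypothetical checkpoint name, used only to illustrate the remapping.
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "c_fc_0", 0),
    ("gate_up_proj", "c_fc_1", 1),
]

def remap(name: str):
    # HF EXAONE checkpoints use "attn.attention"; the sglang module is "self_attn".
    name = name.replace("attn.attention", "self_attn")
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None

print(remap("transformer.h.0.attn.attention.q_proj.weight"))
# -> ('transformer.h.0.self_attn.qkv_proj.weight', 'q')
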
sglang/srt/models/gemma.py
@@ -23,7 +23,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -34,9 +33,11 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -60,7 +61,7 @@ class GemmaMLP(nn.Module):
             bias=False,
             quant_config=quant_config,
         )
-        self.act_fn = GeluAndMul()
+        self.act_fn = GeluAndMul("none")

     def forward(self, x):
         gate_up, _ = self.gate_up_proj(x)
@@ -287,6 +288,7 @@ class GemmaForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.model = GemmaModel(config, quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -297,9 +299,11 @@
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return (sample_output, logits_output)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
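
The GeluAndMul("none") change above points GemmaMLP at sglang's gated-GELU activation with the exact (non-tanh) GELU. A plain-PyTorch reference sketch of what a gated GELU computes, not the fused kernel itself:

import torch
import torch.nn.functional as F

def gelu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # The last dimension holds [gate, up]; apply exact GELU to the gate half
    # and multiply elementwise with the up half.
    gate, up = x.chunk(2, dim=-1)
    return F.gelu(gate, approximate="none") * up

x = torch.randn(2, 8)
print(gelu_and_mul_ref(x).shape)  # torch.Size([2, 4])
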