sglang 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. sglang/bench_latency.py +10 -6
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +0 -4
  4. sglang/lang/backend/runtime_endpoint.py +5 -2
  5. sglang/lang/interpreter.py +1 -1
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +29 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +1 -3
  14. sglang/srt/layers/attention_backend.py +480 -0
  15. sglang/srt/layers/flashinfer_utils.py +235 -0
  16. sglang/srt/layers/logits_processor.py +64 -77
  17. sglang/srt/layers/radix_attention.py +11 -161
  18. sglang/srt/layers/sampler.py +6 -25
  19. sglang/srt/layers/torchao_utils.py +75 -0
  20. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  21. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  22. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  23. sglang/srt/lora/lora.py +403 -0
  24. sglang/srt/lora/lora_config.py +43 -0
  25. sglang/srt/lora/lora_manager.py +256 -0
  26. sglang/srt/managers/controller_multi.py +1 -5
  27. sglang/srt/managers/controller_single.py +0 -5
  28. sglang/srt/managers/io_struct.py +16 -1
  29. sglang/srt/managers/policy_scheduler.py +122 -5
  30. sglang/srt/managers/schedule_batch.py +104 -71
  31. sglang/srt/managers/tokenizer_manager.py +17 -8
  32. sglang/srt/managers/tp_worker.py +181 -115
  33. sglang/srt/model_executor/cuda_graph_runner.py +58 -133
  34. sglang/srt/model_executor/forward_batch_info.py +35 -312
  35. sglang/srt/model_executor/model_runner.py +117 -131
  36. sglang/srt/models/baichuan.py +416 -0
  37. sglang/srt/models/chatglm.py +1 -5
  38. sglang/srt/models/commandr.py +1 -5
  39. sglang/srt/models/dbrx.py +1 -5
  40. sglang/srt/models/deepseek.py +1 -5
  41. sglang/srt/models/deepseek_v2.py +1 -5
  42. sglang/srt/models/exaone.py +1 -5
  43. sglang/srt/models/gemma.py +1 -5
  44. sglang/srt/models/gemma2.py +1 -5
  45. sglang/srt/models/gpt_bigcode.py +1 -5
  46. sglang/srt/models/grok.py +1 -5
  47. sglang/srt/models/internlm2.py +1 -5
  48. sglang/srt/models/llama.py +51 -5
  49. sglang/srt/models/llama_classification.py +1 -20
  50. sglang/srt/models/llava.py +30 -5
  51. sglang/srt/models/llavavid.py +2 -2
  52. sglang/srt/models/minicpm.py +1 -5
  53. sglang/srt/models/minicpm3.py +665 -0
  54. sglang/srt/models/mixtral.py +6 -5
  55. sglang/srt/models/mixtral_quant.py +1 -5
  56. sglang/srt/models/qwen.py +1 -5
  57. sglang/srt/models/qwen2.py +1 -5
  58. sglang/srt/models/qwen2_moe.py +6 -5
  59. sglang/srt/models/stablelm.py +1 -5
  60. sglang/srt/models/xverse.py +375 -0
  61. sglang/srt/models/xverse_moe.py +445 -0
  62. sglang/srt/openai_api/adapter.py +65 -46
  63. sglang/srt/openai_api/protocol.py +11 -3
  64. sglang/srt/sampling/sampling_batch_info.py +57 -44
  65. sglang/srt/server.py +24 -14
  66. sglang/srt/server_args.py +130 -28
  67. sglang/srt/utils.py +12 -0
  68. sglang/test/few_shot_gsm8k.py +132 -0
  69. sglang/test/runners.py +114 -22
  70. sglang/test/test_programs.py +7 -5
  71. sglang/test/test_utils.py +85 -1
  72. sglang/utils.py +32 -37
  73. sglang/version.py +1 -1
  74. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/METADATA +30 -18
  75. sglang-0.3.1.dist-info/RECORD +129 -0
  76. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
  77. sglang-0.3.0.dist-info/RECORD +0 -118
  78. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
  79. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
@@ -45,7 +45,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -334,7 +333,6 @@ class QuantMixtralForCausalLM(nn.Module):
          self.model = MixtralModel(config, quant_config=quant_config)
          self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)
-         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -345,11 +343,9 @@ class QuantMixtralForCausalLM(nn.Module):
          input_embeds: torch.Tensor = None,
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-         logits_output = self.logits_processor(
+         return self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
-         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
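The same edit repeats in every model file that follows: the `Sampler` import, the `self.sampler` attribute, and the sampling call at the end of `forward` are removed, so `forward` returns the `LogitsProcessor` output directly, with sampling presumably handled outside the model (see the model_runner.py and sampling/sampling_batch_info.py entries in the file list). A minimal sketch of the resulting contract; the class below is illustrative rather than an sglang class, and only the return shape mirrors the hunks:

import torch
from torch import nn


class CausalLMSketch(nn.Module):
    """Illustrative wrapper showing the 0.3.1 forward contract."""

    def __init__(self, decoder, lm_head, logits_processor):
        super().__init__()
        self.model = decoder                    # decoder stack
        self.lm_head = lm_head                  # ParallelLMHead-like module
        self.logits_processor = logits_processor
        # Note: no self.sampler attribute anymore.

    @torch.no_grad()
    def forward(self, input_ids, positions, input_metadata, input_embeds=None):
        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
        # 0.3.0 computed sample_output here and returned
        # (sample_output, logits_output); 0.3.1 returns only the logits output.
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )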
sglang/srt/models/qwen.py CHANGED
@@ -39,7 +39,6 @@ from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -252,7 +251,6 @@ class QWenLMHeadModel(nn.Module):
          vocab_size = ((config.vocab_size + 63) // 64) * 64
          self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)
-         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -262,11 +260,9 @@ class QWenLMHeadModel(nn.Module):
          input_metadata: InputMetadata,
      ):
          hidden_states = self.transformer(input_ids, positions, input_metadata)
-         logits_output = self.logits_processor(
+         return self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
-         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
sglang/srt/models/qwen2.py CHANGED
@@ -40,7 +40,6 @@ from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.pooler import Pooler, PoolingType
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata

  Qwen2Config = None
@@ -277,7 +276,6 @@ class Qwen2ForCausalLM(nn.Module):
          self.model = Qwen2Model(config, quant_config=quant_config)
          self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)
-         self.sampler = Sampler()
          self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

      @torch.no_grad()
@@ -291,11 +289,9 @@ class Qwen2ForCausalLM(nn.Module):
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
          if not get_embedding:
-             logits_output = self.logits_processor(
+             return self.logits_processor(
                  input_ids, hidden_states, self.lm_head.weight, input_metadata
              )
-             sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-             return sample_output, logits_output
          else:
              return self.pooler(hidden_states, input_metadata)

sglang/srt/models/qwen2_moe.py CHANGED
@@ -47,7 +47,8 @@ from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
+ from sglang.srt.layers.torchao_utils import apply_torchao_config_
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -360,12 +361,12 @@ class Qwen2MoeForCausalLM(nn.Module):
          super().__init__()
          self.config = config
          self.quant_config = quant_config
+         self.torchao_config = global_server_args_dict["torchao_config"]
          self.model = Qwen2MoeModel(config, cache_config, quant_config)
          self.lm_head = ParallelLMHead(
              config.vocab_size, config.hidden_size, quant_config=quant_config
          )
          self.logits_processor = LogitsProcessor(config)
-         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -376,11 +377,9 @@ class Qwen2MoeForCausalLM(nn.Module):
          input_embeds: torch.Tensor = None,
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-         logits_output = self.logits_processor(
+         return self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
-         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
@@ -455,5 +454,7 @@ class Qwen2MoeForCausalLM(nn.Module):
                  )
                  weight_loader(param, loaded_weight)

+         apply_torchao_config_(self, params_dict, set(["proj.weight"]))
+

  EntryClass = Qwen2MoeForCausalLM
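Besides the sampler removal, qwen2_moe.py is the only file shown here that wires in the new torchao hook: the model records `global_server_args_dict["torchao_config"]` in `__init__` and calls `apply_torchao_config_` once at the end of `load_weights`. A hedged sketch of that call pattern, reusing only the names visible in the hunks above (the implementation of `apply_torchao_config_` lives in the new sglang/srt/layers/torchao_utils.py and is not shown in this diff; the class below is illustrative):

from torch import nn

from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.managers.schedule_batch import global_server_args_dict


class MoeTorchaoSketch(nn.Module):
    """Illustrative only: mirrors the two qwen2_moe.py call sites above."""

    def __init__(self):
        super().__init__()
        # The model records the server-wide torchao setting at construction.
        self.torchao_config = global_server_args_dict["torchao_config"]

    def load_weights(self, weights):
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            ...  # per-parameter loading, as in the hunk above
        # New in 0.3.1: after loading, apply the torchao config to parameters
        # whose names match the "proj.weight" filter.
        apply_torchao_config_(self, params_dict, set(["proj.weight"]))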
sglang/srt/models/stablelm.py CHANGED
@@ -40,7 +40,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -250,7 +249,6 @@ class StableLmForCausalLM(nn.Module):
          self.model = StableLMEpochModel(config, quant_config=quant_config)
          self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)
-         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -261,11 +259,9 @@ class StableLmForCausalLM(nn.Module):
          input_embeds: torch.Tensor = None,
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-         logits_output = self.logits_processor(
+         return self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
-         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
sglang/srt/models/xverse.py ADDED
@@ -0,0 +1,375 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ # Adapted from
+ # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/xverse.py#L1
+ """Inference-only XVERSE model compatible with HuggingFace weights."""
+
+ from typing import Any, Dict, Iterable, Optional, Tuple
+
+ import torch
+ from torch import nn
+ from transformers import LlamaConfig
+ from vllm.config import CacheConfig
+ from vllm.distributed import get_tensor_model_parallel_world_size
+ from vllm.model_executor.layers.activation import SiluAndMul
+ from vllm.model_executor.layers.layernorm import RMSNorm
+ from vllm.model_executor.layers.linear import (
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
+ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+ from vllm.model_executor.layers.rotary_embedding import get_rope
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.model_executor.model_runner import InputMetadata
+
+
+ class XverseMLP(nn.Module):
+     def __init__(
+         self,
+         hidden_size: int,
+         intermediate_size: int,
+         hidden_act: str,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.gate_up_proj = MergedColumnParallelLinear(
+             hidden_size,
+             [intermediate_size] * 2,
+             bias=False,
+             quant_config=quant_config,
+             prefix=f"{prefix}.gate_up_proj",
+         )
+         self.down_proj = RowParallelLinear(
+             intermediate_size,
+             hidden_size,
+             bias=False,
+             quant_config=quant_config,
+             prefix=f"{prefix}.down_proj",
+         )
+         if hidden_act != "silu":
+             raise ValueError(
+                 f"Unsupported activation: {hidden_act}. "
+                 "Only silu is supported for now."
+             )
+         self.act_fn = SiluAndMul()
+
+     def forward(self, x):
+         gate_up, _ = self.gate_up_proj(x)
+         x = self.act_fn(gate_up)
+         x, _ = self.down_proj(x)
+         return x
+
+
+ class XverseAttention(nn.Module):
+     def __init__(
+         self,
+         config: LlamaConfig,
+         hidden_size: int,
+         num_heads: int,
+         num_kv_heads: int,
+         layer_id: int = 0,
+         rope_theta: float = 10000,
+         rope_scaling: Optional[Dict[str, Any]] = None,
+         rope_is_neox_style: bool = True,
+         max_position_embeddings: int = 8192,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.hidden_size = hidden_size
+         tp_size = get_tensor_model_parallel_world_size()
+         self.total_num_heads = num_heads
+         assert self.total_num_heads % tp_size == 0
+         self.num_heads = self.total_num_heads // tp_size
+         self.total_num_kv_heads = num_kv_heads
+         if self.total_num_kv_heads >= tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+         # MistralConfig has an optional head_dim introduced by Mistral-Nemo
+         self.head_dim = getattr(
+             config, "head_dim", self.hidden_size // self.total_num_heads
+         )
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+         self.rope_theta = rope_theta
+         self.max_position_embeddings = max_position_embeddings
+
+         self.qkv_proj = QKVParallelLinear(
+             hidden_size,
+             self.head_dim,
+             self.total_num_heads,
+             self.total_num_kv_heads,
+             bias=False,
+             quant_config=quant_config,
+             prefix=f"{prefix}.qkv_proj",
+         )
+         self.o_proj = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             hidden_size,
+             bias=False,
+             quant_config=quant_config,
+             prefix=f"{prefix}.o_proj",
+         )
+
+         self.rotary_emb = get_rope(
+             self.head_dim,
+             rotary_dim=self.head_dim,
+             max_position=max_position_embeddings,
+             base=rope_theta,
+             rope_scaling=rope_scaling,
+             is_neox_style=rope_is_neox_style,
+         )
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             layer_id=layer_id,
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+         q, k = self.rotary_emb(positions, q, k)
+         attn_output = self.attn(q, k, v, input_metadata)
+         output, _ = self.o_proj(attn_output)
+         return output
+
+
+ class XverseDecoderLayer(nn.Module):
+     def __init__(
+         self,
+         config: LlamaConfig,
+         layer_id: int = 0,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         rope_theta = getattr(config, "rope_theta", 10000)
+         rope_scaling = getattr(config, "rope_scaling", None)
+         if rope_scaling is not None and getattr(
+             config, "original_max_position_embeddings", None
+         ):
+             rope_scaling["original_max_position_embeddings"] = (
+                 config.original_max_position_embeddings
+             )
+         rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
+         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+         num_kv_heads = getattr(
+             config, "num_key_value_heads", config.num_attention_heads
+         )
+         self.self_attn = XverseAttention(
+             config=config,
+             hidden_size=self.hidden_size,
+             num_heads=config.num_attention_heads,
+             num_kv_heads=num_kv_heads,
+             layer_id=layer_id,
+             rope_theta=rope_theta,
+             rope_scaling=rope_scaling,
+             rope_is_neox_style=rope_is_neox_style,
+             max_position_embeddings=max_position_embeddings,
+             quant_config=quant_config,
+             prefix=f"{prefix}.self_attn",
+         )
+         self.mlp = XverseMLP(
+             hidden_size=self.hidden_size,
+             intermediate_size=config.intermediate_size,
+             hidden_act=config.hidden_act,
+             quant_config=quant_config,
+             prefix=f"{prefix}.mlp",
+         )
+         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+         residual: Optional[torch.Tensor],
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         # Self Attention
+         if residual is None:
+             residual = hidden_states
+             hidden_states = self.input_layernorm(hidden_states)
+         else:
+             hidden_states, residual = self.input_layernorm(hidden_states, residual)
+         hidden_states = self.self_attn(
+             positions=positions,
+             hidden_states=hidden_states,
+             input_metadata=input_metadata,
+         )
+
+         # Fully Connected
+         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+         hidden_states = self.mlp(hidden_states)
+         return hidden_states, residual
+
+
+ class XverseModel(nn.Module):
+     def __init__(
+         self,
+         config: LlamaConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+         self.embed_tokens = VocabParallelEmbedding(
+             config.vocab_size,
+             config.hidden_size,
+         )
+         self.layers = nn.ModuleList(
+             [
+                 XverseDecoderLayer(
+                     config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
+                 )
+                 for i in range(config.num_hidden_layers)
+             ]
+         )
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         if input_embeds is None:
+             hidden_states = self.embed_tokens(input_ids)
+         else:
+             hidden_states = input_embeds
+         residual = None
+         for i in range(len(self.layers)):
+             layer = self.layers[i]
+             hidden_states, residual = layer(
+                 positions,
+                 hidden_states,
+                 input_metadata,
+                 residual,
+             )
+             # print(f"layer[{i}].hidden_states: {hidden_states}")
+         hidden_states, _ = self.norm(hidden_states, residual)
+         return hidden_states
+
+
+ class XverseForCausalLM(nn.Module):
+     def __init__(
+         self,
+         config: LlamaConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
+         efficient_weight_load=False,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.quant_config = quant_config
+         self.model = XverseModel(config, quant_config=quant_config)
+         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+         self.logits_processor = LogitsProcessor(config)
+
+         self.param_dict = dict(self.named_parameters())
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+         return self.logits_processor(
+             input_ids, hidden_states, self.lm_head.weight, input_metadata
+         )
+
+     def load_weights(
+         self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
+     ):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+         params_dict = self.param_dict
+
+         def load_weights_per_param(name, loaded_weight):
+             if "rotary_emb.inv_freq" in name or "projector" in name:
+                 return
+             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                 # Models trained using ColossalAI may include these tensors in
+                 # the checkpoint. Skip them.
+                 return
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 if name.startswith("model.vision_tower") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     return
+                 if name.startswith("model.vision_tower") and name not in params_dict:
+                     return
+                 param = params_dict[name]
+                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                 weight_loader(param, loaded_weight)
+
+         if name is None or loaded_weight is None:
+             for name, loaded_weight in weights:
+                 load_weights_per_param(name, loaded_weight)
+         else:
+             load_weights_per_param(name, loaded_weight)
+
+
+ EntryClass = XverseForCausalLM
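The new `load_weights` in xverse.py accepts either a full iterator of `(name, tensor)` pairs or a single pair passed through the `name`/`loaded_weight` keywords, which appears to be what the otherwise unused `efficient_weight_load` flag in `__init__` is for. A short hedged sketch of the two call styles; `iter_checkpoint_weights` is a hypothetical generator of `(parameter_name, tensor)` pairs, not an sglang API:

def load_all_at_once(model, iter_checkpoint_weights):
    # Style 1: hand over the whole iterator; load_weights loops internally.
    model.load_weights(iter_checkpoint_weights())


def load_one_by_one(model, iter_checkpoint_weights):
    # Style 2: pass a single (name, tensor) pair per call via the keyword
    # arguments; the weights iterator is ignored in this branch.
    for name, tensor in iter_checkpoint_weights():
        model.load_weights(None, name=name, loaded_weight=tensor)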