sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/__init__.py +59 -2
  2. sglang/api.py +40 -11
  3. sglang/backend/anthropic.py +17 -3
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +160 -12
  6. sglang/backend/runtime_endpoint.py +62 -27
  7. sglang/backend/vertexai.py +1 -0
  8. sglang/bench_latency.py +320 -0
  9. sglang/global_config.py +24 -3
  10. sglang/lang/chat_template.py +122 -6
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +206 -98
  13. sglang/lang/ir.py +98 -34
  14. sglang/lang/tracer.py +6 -4
  15. sglang/launch_server.py +4 -1
  16. sglang/launch_server_llavavid.py +32 -0
  17. sglang/srt/constrained/__init__.py +14 -6
  18. sglang/srt/constrained/fsm_cache.py +9 -2
  19. sglang/srt/constrained/jump_forward.py +113 -24
  20. sglang/srt/conversation.py +4 -2
  21. sglang/srt/flush_cache.py +18 -0
  22. sglang/srt/hf_transformers_utils.py +144 -3
  23. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  24. sglang/srt/layers/extend_attention.py +20 -1
  25. sglang/srt/layers/fused_moe.py +596 -0
  26. sglang/srt/layers/logits_processor.py +190 -61
  27. sglang/srt/layers/radix_attention.py +62 -53
  28. sglang/srt/layers/token_attention.py +21 -9
  29. sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
  30. sglang/srt/managers/controller/dp_worker.py +113 -0
  31. sglang/srt/managers/controller/infer_batch.py +908 -0
  32. sglang/srt/managers/controller/manager_multi.py +195 -0
  33. sglang/srt/managers/controller/manager_single.py +177 -0
  34. sglang/srt/managers/controller/model_runner.py +359 -0
  35. sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
  36. sglang/srt/managers/controller/schedule_heuristic.py +65 -0
  37. sglang/srt/managers/controller/tp_worker.py +813 -0
  38. sglang/srt/managers/detokenizer_manager.py +42 -40
  39. sglang/srt/managers/io_struct.py +44 -10
  40. sglang/srt/managers/tokenizer_manager.py +224 -82
  41. sglang/srt/memory_pool.py +52 -59
  42. sglang/srt/model_config.py +97 -2
  43. sglang/srt/models/chatglm.py +399 -0
  44. sglang/srt/models/commandr.py +369 -0
  45. sglang/srt/models/dbrx.py +406 -0
  46. sglang/srt/models/gemma.py +34 -38
  47. sglang/srt/models/gemma2.py +436 -0
  48. sglang/srt/models/grok.py +738 -0
  49. sglang/srt/models/llama2.py +47 -37
  50. sglang/srt/models/llama_classification.py +107 -0
  51. sglang/srt/models/llava.py +92 -27
  52. sglang/srt/models/llavavid.py +298 -0
  53. sglang/srt/models/minicpm.py +366 -0
  54. sglang/srt/models/mixtral.py +302 -127
  55. sglang/srt/models/mixtral_quant.py +372 -0
  56. sglang/srt/models/qwen.py +40 -35
  57. sglang/srt/models/qwen2.py +33 -36
  58. sglang/srt/models/qwen2_moe.py +473 -0
  59. sglang/srt/models/stablelm.py +33 -39
  60. sglang/srt/models/yivl.py +19 -26
  61. sglang/srt/openai_api_adapter.py +411 -0
  62. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
  63. sglang/srt/sampling_params.py +2 -0
  64. sglang/srt/server.py +197 -481
  65. sglang/srt/server_args.py +190 -74
  66. sglang/srt/utils.py +460 -95
  67. sglang/test/test_programs.py +73 -10
  68. sglang/test/test_utils.py +226 -7
  69. sglang/utils.py +97 -27
  70. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
  71. sglang-0.1.21.dist-info/RECORD +82 -0
  72. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
  73. sglang/srt/backend_config.py +0 -13
  74. sglang/srt/managers/router/infer_batch.py +0 -503
  75. sglang/srt/managers/router/manager.py +0 -79
  76. sglang/srt/managers/router/model_rpc.py +0 -686
  77. sglang/srt/managers/router/model_runner.py +0 -514
  78. sglang/srt/managers/router/scheduler.py +0 -70
  79. sglang-0.1.14.dist-info/RECORD +0 -64
  80. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
  81. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2.py
@@ -1,33 +1,30 @@
  # Adapted from llama2.py
  # Modify details for the adaptation of Qwen2 model.
  """Inference-only Qwen2 model compatible with HuggingFace weights."""
- from typing import Any, Dict, List, Optional, Tuple
+ from typing import Any, Dict, Iterable, Optional, Tuple

  import torch
- from sglang.srt.layers.logits_processor import LogitsProcessor
- from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.managers.router.model_runner import InputMetadata
  from torch import nn
+ from vllm.config import CacheConfig
+ from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.activation import SiluAndMul
  from vllm.model_executor.layers.layernorm import RMSNorm
  from vllm.model_executor.layers.linear import (
-     LinearMethodBase,
      MergedColumnParallelLinear,
      QKVParallelLinear,
      RowParallelLinear,
  )
+ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
  )
- from vllm.model_executor.parallel_utils.parallel_state import (
-     get_tensor_model_parallel_world_size,
- )
- from vllm.model_executor.weight_utils import (
-     default_weight_loader,
-     hf_model_weights_iterator,
- )
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.managers.controller.model_runner import InputMetadata

  Qwen2Config = None
 
@@ -38,17 +35,20 @@ class Qwen2MLP(nn.Module):
          hidden_size: int,
          intermediate_size: int,
          hidden_act: str,
-         linear_method: Optional[LinearMethodBase] = None,
+         quant_config: Optional[QuantizationConfig] = None,
      ) -> None:
          super().__init__()
          self.gate_up_proj = MergedColumnParallelLinear(
              hidden_size,
              [intermediate_size] * 2,
              bias=False,
-             linear_method=linear_method,
+             quant_config=quant_config,
          )
          self.down_proj = RowParallelLinear(
-             intermediate_size, hidden_size, bias=False, linear_method=linear_method
+             intermediate_size,
+             hidden_size,
+             bias=False,
+             quant_config=quant_config,
          )
          if hidden_act != "silu":
              raise ValueError(
@@ -74,7 +74,7 @@ class Qwen2Attention(nn.Module):
          rope_theta: float = 1000000,
          rope_scaling: Optional[Dict[str, Any]] = None,
          max_position_embeddings: int = 32768,
-         linear_method: Optional[LinearMethodBase] = None,
+         quant_config: Optional[QuantizationConfig] = None,
      ) -> None:
          super().__init__()
          self.hidden_size = hidden_size
@@ -105,13 +105,13 @@ class Qwen2Attention(nn.Module):
              self.total_num_heads,
              self.total_num_kv_heads,
              bias=True,
-             linear_method=linear_method,
+             quant_config=quant_config,
          )
          self.o_proj = RowParallelLinear(
              self.total_num_heads * self.head_dim,
              hidden_size,
              bias=False,
-             linear_method=linear_method,
+             quant_config=quant_config,
          )

          self.rotary_emb = get_rope(
@@ -148,7 +148,7 @@ class Qwen2DecoderLayer(nn.Module):
          self,
          config: Qwen2Config,
          layer_id: int = 0,
-         linear_method: Optional[LinearMethodBase] = None,
+         quant_config: Optional[QuantizationConfig] = None,
      ) -> None:
          super().__init__()
          self.hidden_size = config.hidden_size
@@ -163,13 +163,13 @@ class Qwen2DecoderLayer(nn.Module):
              rope_theta=rope_theta,
              rope_scaling=rope_scaling,
              max_position_embeddings=max_position_embeddings,
-             linear_method=linear_method,
+             quant_config=quant_config,
          )
          self.mlp = Qwen2MLP(
              hidden_size=self.hidden_size,
              intermediate_size=config.intermediate_size,
              hidden_act=config.hidden_act,
-             linear_method=linear_method,
+             quant_config=quant_config,
          )
          self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
          self.post_attention_layernorm = RMSNorm(
@@ -205,7 +205,7 @@ class Qwen2Model(nn.Module):
      def __init__(
          self,
          config: Qwen2Config,
-         linear_method: Optional[LinearMethodBase] = None,
+         quant_config: Optional[QuantizationConfig] = None,
      ) -> None:
          super().__init__()
          self.config = config
@@ -217,7 +217,7 @@ class Qwen2Model(nn.Module):
          )
          self.layers = nn.ModuleList(
              [
-                 Qwen2DecoderLayer(config, i, linear_method)
+                 Qwen2DecoderLayer(config, i, quant_config=quant_config)
                  for i in range(config.num_hidden_layers)
              ]
          )
@@ -251,12 +251,13 @@ class Qwen2ForCausalLM(nn.Module):
      def __init__(
          self,
          config: Qwen2Config,
-         linear_method: Optional[LinearMethodBase] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
      ) -> None:
          super().__init__()
          self.config = config
-         self.linear_method = linear_method
-         self.model = Qwen2Model(config, linear_method)
+         self.quant_config = quant_config
+         self.model = Qwen2Model(config, quant_config=quant_config)
          self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)
 
@@ -272,13 +273,7 @@ class Qwen2ForCausalLM(nn.Module):
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )

-     def load_weights(
-         self,
-         model_name_or_path: str,
-         cache_dir: Optional[str] = None,
-         load_format: str = "auto",
-         revision: Optional[str] = None,
-     ):
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
              # (param_name, shard_name, shard_id)
              ("qkv_proj", "q_proj", "q"),
@@ -288,9 +283,7 @@ class Qwen2ForCausalLM(nn.Module):
              ("gate_up_proj", "up_proj", 1),
          ]
          params_dict = dict(self.named_parameters())
-         for name, loaded_weight in hf_model_weights_iterator(
-             model_name_or_path, cache_dir, load_format, revision
-         ):
+         for name, loaded_weight in weights:
              if "rotary_emb.inv_freq" in name or "projector" in name:
                  continue
              if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
@@ -304,6 +297,8 @@ class Qwen2ForCausalLM(nn.Module):
                  # Skip loading extra bias for GPTQ models.
                  if name.endswith(".bias") and name not in params_dict:
                      continue
+                 if name.startswith("model.vision_tower") and name not in params_dict:
+                     continue
                  param = params_dict[name]
                  weight_loader = param.weight_loader
                  weight_loader(param, loaded_weight, shard_id)
@@ -312,6 +307,8 @@ class Qwen2ForCausalLM(nn.Module):
                  # Skip loading extra bias for GPTQ models.
                  if name.endswith(".bias") and name not in params_dict:
                      continue
+                 if name.startswith("model.vision_tower") and name not in params_dict:
+                     continue
                  param = params_dict[name]
                  weight_loader = getattr(param, "weight_loader", default_weight_loader)
                  weight_loader(param, loaded_weight)
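
Editor's note: the load_weights rework above drops path-based loading through hf_model_weights_iterator in favor of a caller-supplied iterator of (name, tensor) pairs. A minimal sketch of driving the new signature, assuming a local safetensors shard; the checkpoint path, the helper below, and the commented-out model construction are illustrations, not part of the diff.

# Illustrative sketch only: feed (name, tensor) pairs into the new-style
# load_weights(weights: Iterable[Tuple[str, torch.Tensor]]) entry point.
# The file path and the commented-out Qwen2ForCausalLM setup are assumptions.
from typing import Iterator, Tuple

import torch
from safetensors import safe_open


def iter_safetensors(path: str) -> Iterator[Tuple[str, torch.Tensor]]:
    """Yield (parameter_name, tensor) pairs from one safetensors shard."""
    with safe_open(path, framework="pt", device="cpu") as f:
        for name in f.keys():
            yield name, f.get_tensor(name)


# model = Qwen2ForCausalLM(config, quant_config=None)  # hypothetical setup
# model.load_weights(iter_safetensors("model.safetensors"))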
sglang/srt/models/qwen2_moe.py (new file)
@@ -0,0 +1,473 @@
+ # coding=utf-8
+ # Adapted from
+ # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen2_moe.py
+ """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
+ from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from transformers import PretrainedConfig
+ from vllm.config import CacheConfig
+ from vllm.distributed import (
+     get_tensor_model_parallel_world_size,
+     tensor_model_parallel_all_reduce,
+ )
+ from vllm.model_executor.layers.activation import SiluAndMul
+ from vllm.model_executor.layers.fused_moe import FusedMoE
+ from vllm.model_executor.layers.layernorm import RMSNorm
+ from vllm.model_executor.layers.linear import (
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     ReplicatedLinear,
+     RowParallelLinear,
+ )
+ from vllm.model_executor.layers.logits_processor import LogitsProcessor
+ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+ from vllm.model_executor.layers.rotary_embedding import get_rope
+ from vllm.model_executor.layers.sampler import Sampler
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
+ from vllm.sequence import IntermediateTensors, SamplerOutput
+
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.managers.controller.model_runner import InputMetadata
+
+
+ class Qwen2MoeMLP(nn.Module):
+     def __init__(
+         self,
+         hidden_size: int,
+         intermediate_size: int,
+         hidden_act: str,
+         quant_config: Optional[QuantizationConfig] = None,
+         reduce_results: bool = True,
+     ) -> None:
+         super().__init__()
+         self.gate_up_proj = MergedColumnParallelLinear(
+             hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+         )
+         self.down_proj = RowParallelLinear(
+             intermediate_size,
+             hidden_size,
+             bias=False,
+             quant_config=quant_config,
+             reduce_results=reduce_results,
+         )
+         if hidden_act != "silu":
+             raise ValueError(
+                 f"Unsupported activation: {hidden_act}. "
+                 "Only silu is supported for now."
+             )
+         self.act_fn = SiluAndMul()
+
+     def forward(self, x):
+         gate_up, _ = self.gate_up_proj(x)
+         x = self.act_fn(gate_up)
+         x, _ = self.down_proj(x)
+         return x
+
+
+ class Qwen2MoeSparseMoeBlock(nn.Module):
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+     ):
+         super().__init__()
+         self.tp_size = get_tensor_model_parallel_world_size()
+
+         if self.tp_size > config.num_experts:
+             raise ValueError(
+                 f"Tensor parallel size {self.tp_size} is greater than "
+                 f"the number of experts {config.num_experts}."
+             )
+
+         self.experts = FusedMoE(
+             num_experts=config.num_experts,
+             top_k=config.num_experts_per_tok,
+             hidden_size=config.hidden_size,
+             intermediate_size=config.moe_intermediate_size,
+             reduce_results=False,
+             renormalize=config.norm_topk_prob,
+             quant_config=quant_config,
+         )
+
+         self.gate = ReplicatedLinear(
+             config.hidden_size, config.num_experts, bias=False, quant_config=None
+         )
+         if config.shared_expert_intermediate_size > 0:
+             self.shared_expert = Qwen2MoeMLP(
+                 hidden_size=config.hidden_size,
+                 intermediate_size=config.shared_expert_intermediate_size,
+                 hidden_act=config.hidden_act,
+                 quant_config=quant_config,
+                 reduce_results=False,
+             )
+         else:
+             self.shared_expert = None
+         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         num_tokens, hidden_dim = hidden_states.shape
+         hidden_states = hidden_states.view(-1, hidden_dim)
+         shared_output = None
+         if self.shared_expert is not None:
+             shared_output = self.shared_expert(hidden_states)
+             if self.shared_expert_gate is not None:
+                 shared_output = (
+                     F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output
+                 )
+
+         # router_logits: (num_tokens, n_experts)
+         router_logits, _ = self.gate(hidden_states)
+         final_hidden_states = self.experts(
+             hidden_states=hidden_states, router_logits=router_logits
+         )
+         if shared_output is not None:
+             final_hidden_states = final_hidden_states + shared_output
+         if self.tp_size > 1:
+             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+         return final_hidden_states.view(num_tokens, hidden_dim)
+
+
+ class Qwen2MoeAttention(nn.Module):
+     def __init__(
+         self,
+         hidden_size: int,
+         num_heads: int,
+         num_kv_heads: int,
+         layer_id: int = 0,
+         rope_theta: float = 10000,
+         rope_scaling: Optional[Dict[str, Any]] = None,
+         max_position_embeddings: int = 8192,
+         cache_config: Optional[CacheConfig] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.hidden_size = hidden_size
+         tp_size = get_tensor_model_parallel_world_size()
+         self.total_num_heads = num_heads
+         assert self.total_num_heads % tp_size == 0
+         self.num_heads = self.total_num_heads // tp_size
+         self.total_num_kv_heads = num_kv_heads
+         if self.total_num_kv_heads >= tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+         self.head_dim = hidden_size // self.total_num_heads
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+         self.rope_theta = rope_theta
+         self.max_position_embeddings = max_position_embeddings
+
+         self.qkv_proj = QKVParallelLinear(
+             hidden_size,
+             self.head_dim,
+             self.total_num_heads,
+             self.total_num_kv_heads,
+             bias=True,
+             quant_config=quant_config,
+         )
+
+         self.o_proj = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             hidden_size,
+             bias=False,
+             quant_config=quant_config,
+         )
+
+         self.rotary_emb = get_rope(
+             self.head_dim,
+             rotary_dim=self.head_dim,
+             max_position=max_position_embeddings,
+             base=rope_theta,
+             rope_scaling=rope_scaling,
+         )
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             layer_id=layer_id,
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+         q, k = self.rotary_emb(positions, q, k)
+         attn_output = self.attn(q, k, v, input_metadata)
+         output, _ = self.o_proj(attn_output)
+         return output
+
+
+ class Qwen2MoeDecoderLayer(nn.Module):
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         layer_id: int,
+         cache_config: Optional[CacheConfig] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         rope_theta = getattr(config, "rope_theta", 10000)
+         rope_scaling = getattr(config, "rope_scaling", None)
+         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+         self.self_attn = Qwen2MoeAttention(
+             hidden_size=self.hidden_size,
+             num_heads=config.num_attention_heads,
+             num_kv_heads=config.num_key_value_heads,
+             layer_id=layer_id,
+             rope_theta=rope_theta,
+             rope_scaling=rope_scaling,
+             max_position_embeddings=max_position_embeddings,
+             cache_config=cache_config,
+             quant_config=quant_config,
+         )
+
+         # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
+         # `mlp_only_layers` in the config.
+         mlp_only_layers = (
+             [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
+         )
+         if (layer_id not in mlp_only_layers) and (
+             config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
+         ):
+             self.mlp = Qwen2MoeSparseMoeBlock(config=config, quant_config=quant_config)
+         else:
+             self.mlp = Qwen2MoeMLP(
+                 hidden_size=config.hidden_size,
+                 intermediate_size=config.intermediate_size,
+                 hidden_act=config.hidden_act,
+                 quant_config=quant_config,
+             )
+         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+         residual: Optional[torch.Tensor],
+     ) -> torch.Tensor:
+         # Self Attention
+         if residual is None:
+             residual = hidden_states
+             hidden_states = self.input_layernorm(hidden_states)
+         else:
+             hidden_states, residual = self.input_layernorm(hidden_states, residual)
+         hidden_states = self.self_attn(
+             positions=positions,
+             hidden_states=hidden_states,
+             input_metadata=input_metadata,
+         )
+
+         # Fully Connected
+         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+         hidden_states = self.mlp(hidden_states)
+         return hidden_states, residual
+
+
+ class Qwen2MoeModel(nn.Module):
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         cache_config: Optional[CacheConfig] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+
+         self.embed_tokens = VocabParallelEmbedding(
+             config.vocab_size,
+             config.hidden_size,
+         )
+         self.layers = nn.ModuleList(
+             [
+                 Qwen2MoeDecoderLayer(
+                     config, layer_id, cache_config, quant_config=quant_config
+                 )
+                 for layer_id in range(config.num_hidden_layers)
+             ]
+         )
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         if input_embeds is None:
+             hidden_states = self.embed_tokens(input_ids)
+         else:
+             hidden_states = input_embeds
+         residual = None
+         for i in range(len(self.layers)):
+             layer = self.layers[i]
+             hidden_states, residual = layer(
+                 positions, hidden_states, input_metadata, residual
+             )
+         hidden_states, _ = self.norm(hidden_states, residual)
+         return hidden_states
+
+
+ class Qwen2MoeForCausalLM(nn.Module):
+
+     fall_back_to_pt_during_load = False
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         cache_config: Optional[CacheConfig] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.quant_config = quant_config
+         self.model = Qwen2MoeModel(config, cache_config, quant_config)
+         self.lm_head = ParallelLMHead(
+             config.vocab_size, config.hidden_size, quant_config=quant_config
+         )
+         self.logits_processor = LogitsProcessor(config)
+         self.sampler = Sampler()
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+         return self.logits_processor(
+             input_ids, hidden_states, self.lm_head.weight, input_metadata
+         )
+
+     def compute_logits(
+         self,
+         input_ids: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         logits = self.logits_processor(
+             input_ids, hidden_states, self.lm_head.weight, input_metadata
+         )
+         return logits
+
+     def sample(
+         self,
+         logits: Optional[torch.Tensor],
+         sampling_metadata: SamplingMetadata,
+     ) -> Optional[SamplerOutput]:
+         next_tokens = self.sampler(logits, sampling_metadata)
+         return next_tokens
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+
+         expert_params_mapping = [
+             # These are the weights for the experts
+             # (param_name, weight_name, expert_id, shard_id)
+             (
+                 "experts.w13_weight"
+                 if weight_name in ["gate_proj", "up_proj"]
+                 else "experts.w2_weight",
+                 f"experts.{expert_id}.{weight_name}.weight",
+                 expert_id,
+                 shard_id,
+             )
+             for expert_id in range(self.config.num_experts)
+             for shard_id, weight_name in enumerate(
+                 ["gate_proj", "down_proj", "up_proj"]
+             )
+         ]
+
+         params_dict = dict(self.named_parameters())
+         for name, loaded_weight in weights:
+             if "rotary_emb.inv_freq" in name:
+                 continue
+             for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                 # Skip non-stacked layers and experts (experts handled below).
+                 if weight_name not in name:
+                     continue
+                 # We have mlp.experts[0].gate_proj in the checkpoint.
+                 # Since we handle the experts below in expert_params_mapping,
+                 # we need to skip here BEFORE we update the name, otherwise
+                 # name will be updated to mlp.experts[0].gate_up_proj, which
+                 # will then be updated below in expert_params_mapping
+                 # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                 if "mlp.experts" in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 if name not in params_dict:
+                     continue
+
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 for mapping in expert_params_mapping:
+                     param_name, weight_name, expert_id, shard_id = mapping
+                     if weight_name not in name:
+                         continue
+                     name = name.replace(weight_name, param_name)
+                     param = params_dict[name]
+                     weight_loader = param.weight_loader
+                     weight_loader(
+                         param,
+                         loaded_weight,
+                         weight_name,
+                         shard_id=shard_id,
+                         expert_id=expert_id,
+                     )
+                     break
+                 else:
+                     # Skip loading extra bias for GPTQ models.
+                     if name.endswith(".bias") and name not in params_dict:
+                         continue
+                     if name not in params_dict:
+                         continue
+
+                     param = params_dict[name]
+                     weight_loader = getattr(
+                         param, "weight_loader", default_weight_loader
+                     )
+                     weight_loader(param, loaded_weight)
+
+
+ EntryClass = Qwen2MoeForCausalLM
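
Editor's note: the expert_params_mapping comprehension in the new file's load_weights expands per-expert checkpoint names into the fused parameter names that FusedMoE expects. A small standalone sketch of that expansion, assuming num_experts=2 purely for illustration (the real value comes from the model config):

# Standalone sketch of the expert_params_mapping expansion used above,
# with num_experts=2 assumed only for illustration.
num_experts = 2

expert_params_mapping = [
    (
        "experts.w13_weight"
        if weight_name in ["gate_proj", "up_proj"]
        else "experts.w2_weight",
        f"experts.{expert_id}.{weight_name}.weight",
        expert_id,
        shard_id,
    )
    for expert_id in range(num_experts)
    for shard_id, weight_name in enumerate(["gate_proj", "down_proj", "up_proj"])
]

# A checkpoint name such as
#   "model.layers.0.mlp.experts.1.gate_proj.weight"
# matches ("experts.w13_weight", "experts.1.gate_proj.weight", 1, 0) and is
# remapped to "model.layers.0.mlp.experts.w13_weight" before the FusedMoE
# weight_loader is called with shard_id=0 and expert_id=1.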