sglang 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff compares two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in that registry.
Files changed (69)
  1. sglang/__init__.py +5 -1
  2. sglang/api.py +8 -3
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +148 -12
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/global_config.py +11 -1
  8. sglang/lang/chat_template.py +9 -2
  9. sglang/lang/interpreter.py +161 -81
  10. sglang/lang/ir.py +29 -11
  11. sglang/lang/tracer.py +1 -1
  12. sglang/launch_server.py +1 -2
  13. sglang/launch_server_llavavid.py +31 -0
  14. sglang/srt/constrained/fsm_cache.py +3 -0
  15. sglang/srt/flush_cache.py +16 -0
  16. sglang/srt/hf_transformers_utils.py +83 -2
  17. sglang/srt/layers/extend_attention.py +17 -0
  18. sglang/srt/layers/fused_moe.py +485 -0
  19. sglang/srt/layers/logits_processor.py +12 -7
  20. sglang/srt/layers/radix_attention.py +10 -3
  21. sglang/srt/layers/token_attention.py +16 -1
  22. sglang/srt/managers/controller/dp_worker.py +110 -0
  23. sglang/srt/managers/controller/infer_batch.py +619 -0
  24. sglang/srt/managers/controller/manager_multi.py +191 -0
  25. sglang/srt/managers/controller/manager_single.py +97 -0
  26. sglang/srt/managers/controller/model_runner.py +462 -0
  27. sglang/srt/managers/controller/radix_cache.py +267 -0
  28. sglang/srt/managers/controller/schedule_heuristic.py +59 -0
  29. sglang/srt/managers/controller/tp_worker.py +791 -0
  30. sglang/srt/managers/detokenizer_manager.py +45 -45
  31. sglang/srt/managers/io_struct.py +26 -10
  32. sglang/srt/managers/router/infer_batch.py +130 -74
  33. sglang/srt/managers/router/manager.py +7 -9
  34. sglang/srt/managers/router/model_rpc.py +224 -135
  35. sglang/srt/managers/router/model_runner.py +94 -107
  36. sglang/srt/managers/router/radix_cache.py +54 -18
  37. sglang/srt/managers/router/scheduler.py +23 -34
  38. sglang/srt/managers/tokenizer_manager.py +183 -88
  39. sglang/srt/model_config.py +5 -2
  40. sglang/srt/models/commandr.py +15 -22
  41. sglang/srt/models/dbrx.py +22 -29
  42. sglang/srt/models/gemma.py +14 -24
  43. sglang/srt/models/grok.py +671 -0
  44. sglang/srt/models/llama2.py +24 -23
  45. sglang/srt/models/llava.py +85 -25
  46. sglang/srt/models/llavavid.py +298 -0
  47. sglang/srt/models/mixtral.py +254 -130
  48. sglang/srt/models/mixtral_quant.py +373 -0
  49. sglang/srt/models/qwen.py +28 -25
  50. sglang/srt/models/qwen2.py +17 -22
  51. sglang/srt/models/stablelm.py +21 -26
  52. sglang/srt/models/yivl.py +17 -25
  53. sglang/srt/openai_api_adapter.py +140 -95
  54. sglang/srt/openai_protocol.py +10 -1
  55. sglang/srt/server.py +101 -52
  56. sglang/srt/server_args.py +59 -11
  57. sglang/srt/utils.py +242 -75
  58. sglang/test/test_programs.py +44 -0
  59. sglang/test/test_utils.py +32 -1
  60. sglang/utils.py +95 -26
  61. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/METADATA +23 -13
  62. sglang-0.1.17.dist-info/RECORD +81 -0
  63. sglang/srt/backend_config.py +0 -13
  64. sglang/srt/models/dbrx_config.py +0 -281
  65. sglang/srt/weight_utils.py +0 -402
  66. sglang-0.1.15.dist-info/RECORD +0 -69
  67. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
  68. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
  69. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/srt/models/mixtral_quant.py ADDED
@@ -0,0 +1,373 @@
+ # Adapted from
+ # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral_quant.py#L1
+ """Inference-only Mixtral model."""
+ from typing import Iterable, Optional, Tuple
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from transformers import MixtralConfig
+ from vllm.config import CacheConfig
+ from vllm.distributed import (
+     get_tensor_model_parallel_rank,
+     get_tensor_model_parallel_world_size,
+     tensor_model_parallel_all_reduce,
+ )
+ from vllm.model_executor.layers.layernorm import RMSNorm
+ from vllm.model_executor.layers.linear import (
+     QKVParallelLinear,
+     ReplicatedLinear,
+     RowParallelLinear,
+ )
+ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+ from vllm.model_executor.layers.rotary_embedding import get_rope
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.managers.controller.model_runner import InputMetadata
+
+
+ class MixtralMLP(nn.Module):
+     def __init__(
+         self,
+         num_experts: int,
+         hidden_size: int,
+         intermediate_size: int,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.num_experts = num_experts
+         self.ffn_dim = intermediate_size
+         self.hidden_dim = hidden_size
+
+         self.w1 = ReplicatedLinear(
+             self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
+         )
+         self.w2 = ReplicatedLinear(
+             self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config
+         )
+         self.w3 = ReplicatedLinear(
+             self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
+         )
+
+         # TODO: Use vllm's SiluAndMul
+         self.act_fn = nn.SiLU()
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         w1_out, _ = self.w1(hidden_states)
+         w1_out = self.act_fn(w1_out)
+         w3_out, _ = self.w3(hidden_states)
+         current_hidden_states = w1_out * w3_out
+         current_hidden_states, _ = self.w2(current_hidden_states)
+         return current_hidden_states
+
+
+ class MixtralMoE(nn.Module):
+     def __init__(
+         self,
+         config: MixtralConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+     ):
+         super().__init__()
+         self.config = config
+         self.rank = get_tensor_model_parallel_rank()
+         self.tp_size = get_tensor_model_parallel_world_size()
+         self.num_total_experts = config.num_local_experts
+         self.top_k = config.num_experts_per_tok
+         if self.tp_size > self.num_total_experts:
+             raise ValueError(
+                 f"Tensor parallel size {self.tp_size} is greater than "
+                 f"the number of experts {self.num_total_experts}."
+             )
+         # Split experts equally between ranks
+         self.expert_indicies = np.array_split(
+             range(self.num_total_experts), self.tp_size
+         )[self.rank].tolist()
+         if not self.expert_indicies:
+             raise ValueError(f"Rank {self.rank} has no experts assigned to it.")
+
+         self.experts = nn.ModuleList(
+             [
+                 (
+                     MixtralMLP(
+                         self.num_total_experts,
+                         config.hidden_size,
+                         config.intermediate_size,
+                         quant_config=quant_config,
+                     )
+                     if idx in self.expert_indicies
+                     else None
+                 )
+                 for idx in range(self.num_total_experts)
+             ]
+         )
+         self.gate = ReplicatedLinear(
+             config.hidden_size, self.num_total_experts, bias=False, quant_config=None
+         )
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         router_logits, _ = self.gate(hidden_states)
+
+         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+         routing_weights, selected_experts = torch.topk(
+             routing_weights, self.top_k, dim=-1
+         )
+         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+
+         final_hidden_states = None
+         for expert_idx in self.expert_indicies:
+             expert_layer = self.experts[expert_idx]
+             expert_mask = selected_experts == expert_idx
+             expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)
+
+             current_hidden_states = expert_layer(hidden_states).mul_(expert_weights)
+             if final_hidden_states is None:
+                 final_hidden_states = current_hidden_states
+             else:
+                 final_hidden_states.add_(current_hidden_states)
+
+         return tensor_model_parallel_all_reduce(final_hidden_states)
+
+
+ class MixtralAttention(nn.Module):
+     def __init__(
+         self,
+         hidden_size: int,
+         num_heads: int,
+         num_kv_heads: int,
+         layer_id: int = 0,
+         max_position: int = 4096 * 32,
+         rope_theta: float = 10000,
+         quant_config: Optional[QuantizationConfig] = None,
+         sliding_window: Optional[int] = None,
+     ) -> None:
+         super().__init__()
+         self.hidden_size = hidden_size
+         tp_size = get_tensor_model_parallel_world_size()
+         self.total_num_heads = num_heads
+         assert self.total_num_heads % tp_size == 0
+         self.num_heads = self.total_num_heads // tp_size
+         self.total_num_kv_heads = num_kv_heads
+         if self.total_num_kv_heads >= tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+         self.head_dim = hidden_size // self.total_num_heads
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+         self.rope_theta = rope_theta
+         self.sliding_window = sliding_window
+
+         self.qkv_proj = QKVParallelLinear(
+             hidden_size,
+             self.head_dim,
+             self.total_num_heads,
+             self.total_num_kv_heads,
+             bias=False,
+             quant_config=quant_config,
+         )
+         self.o_proj = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             hidden_size,
+             bias=False,
+             quant_config=quant_config,
+         )
+         self.rotary_emb = get_rope(
+             self.head_dim,
+             rotary_dim=self.head_dim,
+             max_position=max_position,
+             base=int(self.rope_theta),
+             is_neox_style=True,
+         )
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             layer_id=layer_id,
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+         q, k = self.rotary_emb(positions, q, k)
+         attn_output = self.attn(q, k, v, input_metadata)
+         output, _ = self.o_proj(attn_output)
+         return output
+
+
+ class MixtralDecoderLayer(nn.Module):
+     def __init__(
+         self,
+         config: MixtralConfig,
+         layer_id: int = 0,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         # Requires transformers > 4.32.0
+         rope_theta = getattr(config, "rope_theta", 10000)
+         self.self_attn = MixtralAttention(
+             hidden_size=self.hidden_size,
+             num_heads=config.num_attention_heads,
+             max_position=config.max_position_embeddings,
+             num_kv_heads=config.num_key_value_heads,
+             layer_id=layer_id,
+             rope_theta=rope_theta,
+             sliding_window=config.sliding_window,
+             quant_config=quant_config,
+         )
+         self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config)
+         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+         residual: Optional[torch.Tensor],
+     ) -> torch.Tensor:
+         # Self Attention
+         if residual is None:
+             residual = hidden_states
+             hidden_states = self.input_layernorm(hidden_states)
+         else:
+             hidden_states, residual = self.input_layernorm(hidden_states, residual)
+         hidden_states = self.self_attn(
+             positions=positions,
+             hidden_states=hidden_states,
+             input_metadata=input_metadata,
+         )
+
+         # Fully Connected
+         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+         hidden_states = self.block_sparse_moe(hidden_states)
+         return hidden_states, residual
+
+
+ class MixtralModel(nn.Module):
+     def __init__(
+         self,
+         config: MixtralConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+
+         self.embed_tokens = VocabParallelEmbedding(
+             config.vocab_size,
+             config.hidden_size,
+         )
+         self.layers = nn.ModuleList(
+             [
+                 MixtralDecoderLayer(config, i, quant_config=quant_config)
+                 for i in range(config.num_hidden_layers)
+             ]
+         )
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         if input_embeds is None:
+             hidden_states = self.embed_tokens(input_ids)
+         else:
+             hidden_states = input_embeds
+         residual = None
+         for i in range(len(self.layers)):
+             layer = self.layers[i]
+             hidden_states, residual = layer(
+                 positions, hidden_states, input_metadata, residual
+             )
+         hidden_states, _ = self.norm(hidden_states, residual)
+         return hidden_states
+
+
+ class QuantMixtralForCausalLM(nn.Module):
+     def __init__(
+         self,
+         config: MixtralConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.quant_config = quant_config
+         self.model = MixtralModel(config, quant_config=quant_config)
+         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+         self.logits_processor = LogitsProcessor(config)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+         return self.logits_processor(
+             input_ids, hidden_states, self.lm_head.weight, input_metadata
+         )
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+         ]
+
+         params_dict = dict(self.named_parameters())
+         for name, loaded_weight in weights:
+             if "rotary_emb.inv_freq" in name:
+                 continue
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 # Skip experts that are not assigned to this worker.
+                 if "block_sparse_moe.experts." in name and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                 weight_loader(param, loaded_weight)
+
+
+ EntryClass = QuantMixtralForCausalLM
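
For readers scanning the new MixtralMoE above: each rank keeps only its slice of the experts (np.array_split over ranks), evaluates those experts for the tokens routed to it, and the final all-reduce sums the partial results across ranks. The routing itself is just softmax, top-k, and renormalization. A minimal standalone sketch of that routing step in plain PyTorch, with illustrative names that are not part of the package:

import torch
import torch.nn.functional as F

def route_tokens(router_logits: torch.Tensor, top_k: int):
    # Mirror of the gating math in MixtralMoE.forward: softmax over all
    # experts, keep the top-k per token, renormalize so weights sum to 1.
    weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
    weights, expert_ids = torch.topk(weights, top_k, dim=-1)
    weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights, expert_ids

# Example: 3 tokens, 8 experts, top-2 routing (Mixtral's defaults are 8 and 2).
logits = torch.randn(3, 8)
weights, expert_ids = route_tokens(logits, top_k=2)
print(expert_ids)           # which 2 experts each token is sent to
print(weights.sum(dim=-1))  # ~1.0 per token

Each rank then scales the outputs of the selected experts it owns by these weights and relies on the all-reduce to add in the contributions computed on other ranks.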
sglang/srt/models/qwen.py CHANGED
@@ -1,8 +1,12 @@
- from typing import Any, Dict, Optional
+ # Adapted from
+ # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
+ from typing import Any, Dict, Optional, Iterable, Tuple

  import torch
  from torch import nn
  from transformers import PretrainedConfig
+ from vllm.config import CacheConfig
+ from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.activation import SiluAndMul
  from vllm.model_executor.layers.layernorm import RMSNorm
  from vllm.model_executor.layers.linear import (
@@ -10,24 +14,17 @@ from vllm.model_executor.layers.linear import (
      QKVParallelLinear,
      RowParallelLinear,
  )
- from vllm.model_executor.layers.quantization.base_config import (
-     QuantizationConfig)
+ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
  )
- from vllm.distributed import (
-     get_tensor_model_parallel_world_size,
- )
- from sglang.srt.weight_utils import (
-     default_weight_loader,
-     hf_model_weights_iterator,
- )
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.managers.router.model_runner import InputMetadata
+ from sglang.srt.managers.controller.model_runner import InputMetadata


  class QWenMLP(nn.Module):
@@ -132,7 +129,12 @@ class QWenAttention(nn.Module):


  class QWenBlock(nn.Module):
-     def __init__(self, config: PretrainedConfig, layer_id, quant_config: Optional[QuantizationConfig] = None,):
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         layer_id,
+         quant_config: Optional[QuantizationConfig] = None,
+     ):
          super().__init__()
          self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

@@ -181,7 +183,11 @@ class QWenBlock(nn.Module):


  class QWenModel(nn.Module):
-     def __init__(self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,):
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+     ):
          super().__init__()
          self.config = config
          self.vocab_size = config.vocab_size
@@ -218,7 +224,12 @@ class QWenModel(nn.Module):


  class QWenLMHeadModel(nn.Module):
-     def __init__(self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,):
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
+     ):
          super().__init__()
          self.config = config
          self.transformer = QWenModel(config, quant_config=quant_config)
@@ -238,22 +249,14 @@ class QWenLMHeadModel(nn.Module):
          )
          return next_tokens

-     def load_weights(
-         self,
-         model_name_or_path: str,
-         cache_dir: Optional[str] = None,
-         load_format: str = "auto",
-         revision: Optional[str] = None,
-     ):
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
              # (param_name, shard_name, shard_id)
              ("gate_up_proj", "w2", 0),
              ("gate_up_proj", "w1", 1),
          ]
          params_dict = dict(self.named_parameters())
-         for name, loaded_weight in hf_model_weights_iterator(
-             model_name_or_path, cache_dir, load_format, revision
-         ):
+         for name, loaded_weight in weights:
              if "rotary_emb.inv_freq" in name:
                  continue
              for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -276,4 +279,4 @@ class QWenLMHeadModel(nn.Module):
                  weight_loader(param, loaded_weight)


- EntryClass = QWenLMHeadModel
+ EntryClass = QWenLMHeadModel
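
The most substantial change in this file (repeated in qwen2.py and stablelm.py below) is the load_weights interface: instead of taking a checkpoint path and re-reading it through the removed sglang.srt.weight_utils.hf_model_weights_iterator, the model now receives an iterable of (name, tensor) pairs from its caller, and stacked_params_mapping folds separate checkpoint shards (here QWen's w1/w2 MLP weights) into one fused parameter. A toy, self-contained sketch of that idea; ToyMLP and the simple row-slicing loader are invented for illustration, whereas the real code delegates to each parameter's weight_loader:

import torch

class ToyMLP(torch.nn.Module):
    # Stand-in for a module with a fused parameter: shard 0 holds the rows of
    # the checkpoint's "w2" tensor, shard 1 the rows of "w1", as in the
    # gate_up_proj mapping above.
    def __init__(self, hidden=4, inter=8):
        super().__init__()
        self.gate_up_proj = torch.nn.Parameter(torch.empty(2 * inter, hidden))

def load_weights(module, weights):
    # New-style entry point: an iterable of (name, tensor) pairs, no paths.
    stacked_params_mapping = [("gate_up_proj", "w2", 0), ("gate_up_proj", "w1", 1)]
    params_dict = dict(module.named_parameters())
    for name, tensor in weights:
        for param_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name not in name:
                continue
            param = params_dict[name.replace(shard_name, param_name)]
            rows = tensor.shape[0]
            # Copy this shard into its slice of the fused parameter.
            param.data[shard_id * rows:(shard_id + 1) * rows].copy_(tensor)
            break

mlp = ToyMLP()
load_weights(mlp, [("w2", torch.randn(8, 4)), ("w1", torch.randn(8, 4))])
print(mlp.gate_up_proj.shape)  # torch.Size([16, 4]), both shards filled

Because the caller now produces the weights iterable from the checkpoint, the per-model load_weights no longer needs the cache_dir, load_format, or revision arguments removed in this diff.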
sglang/srt/models/qwen2.py CHANGED
@@ -1,10 +1,12 @@
  # Adapted from llama2.py
  # Modify details for the adaptation of Qwen2 model.
  """Inference-only Qwen2 model compatible with HuggingFace weights."""
- from typing import Any, Dict, Optional, Tuple
+ from typing import Any, Dict, Optional, Tuple, Iterable

  import torch
  from torch import nn
+ from vllm.config import CacheConfig
+ from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.activation import SiluAndMul
  from vllm.model_executor.layers.layernorm import RMSNorm
  from vllm.model_executor.layers.linear import (
@@ -12,24 +14,17 @@ from vllm.model_executor.layers.linear import (
      QKVParallelLinear,
      RowParallelLinear,
  )
- from vllm.model_executor.layers.quantization.base_config import (
-     QuantizationConfig)
+ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
  )
- from vllm.distributed import (
-     get_tensor_model_parallel_world_size,
- )
- from sglang.srt.weight_utils import (
-     default_weight_loader,
-     hf_model_weights_iterator,
- )
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.managers.router.model_runner import InputMetadata
+ from sglang.srt.managers.controller.model_runner import InputMetadata

  Qwen2Config = None

@@ -50,7 +45,10 @@ class Qwen2MLP(nn.Module):
              quant_config=quant_config,
          )
          self.down_proj = RowParallelLinear(
-             intermediate_size, hidden_size, bias=False, quant_config=quant_config,
+             intermediate_size,
+             hidden_size,
+             bias=False,
+             quant_config=quant_config,
          )
          if hidden_act != "silu":
              raise ValueError(
@@ -254,6 +252,7 @@ class Qwen2ForCausalLM(nn.Module):
          self,
          config: Qwen2Config,
          quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
      ) -> None:
          super().__init__()
          self.config = config
@@ -274,13 +273,7 @@ class Qwen2ForCausalLM(nn.Module):
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )

-     def load_weights(
-         self,
-         model_name_or_path: str,
-         cache_dir: Optional[str] = None,
-         load_format: str = "auto",
-         revision: Optional[str] = None,
-     ):
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
              # (param_name, shard_name, shard_id)
              ("qkv_proj", "q_proj", "q"),
@@ -290,9 +283,7 @@ class Qwen2ForCausalLM(nn.Module):
              ("gate_up_proj", "up_proj", 1),
          ]
          params_dict = dict(self.named_parameters())
-         for name, loaded_weight in hf_model_weights_iterator(
-             model_name_or_path, cache_dir, load_format, revision
-         ):
+         for name, loaded_weight in weights:
              if "rotary_emb.inv_freq" in name or "projector" in name:
                  continue
              if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
@@ -306,6 +297,8 @@ class Qwen2ForCausalLM(nn.Module):
                  # Skip loading extra bias for GPTQ models.
                  if name.endswith(".bias") and name not in params_dict:
                      continue
+                 if name.startswith("model.vision_tower") and name not in params_dict:
+                     continue
                  param = params_dict[name]
                  weight_loader = param.weight_loader
                  weight_loader(param, loaded_weight, shard_id)
@@ -314,6 +307,8 @@ class Qwen2ForCausalLM(nn.Module):
                  # Skip loading extra bias for GPTQ models.
                  if name.endswith(".bias") and name not in params_dict:
                      continue
+                 if name.startswith("model.vision_tower") and name not in params_dict:
+                     continue
                  param = params_dict[name]
                  weight_loader = getattr(param, "weight_loader", default_weight_loader)
                  weight_loader(param, loaded_weight)
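
Beyond the shared load_weights migration, the notable addition in qwen2.py is the pair of model.vision_tower guards: checkpoints from vision-language variants can bundle vision-tower tensors that a text-only Qwen2 model never declares, and the loader now skips them instead of failing on the params_dict lookup. A tiny illustration with hypothetical names:

# Hypothetical parameter and checkpoint names, for illustration only.
params_dict = {"model.embed_tokens.weight": "..."}
checkpoint_names = [
    "model.embed_tokens.weight",
    "model.vision_tower.vision_model.embeddings.patch_embedding.weight",
]
for name in checkpoint_names:
    if name.startswith("model.vision_tower") and name not in params_dict:
        continue  # vision weights with no matching parameter are ignored
    print("loading", name)  # only model.embed_tokens.weight is loaded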
sglang/srt/models/stablelm.py CHANGED
@@ -1,41 +1,38 @@
- # This code is based on:
- # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/stablelm.py
+ # Adapted from:
+ # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/stablelm.py#L1
  """Inference-only StableLM-2 (https://huggingface.co/stabilityai/stablelm-2-1_6b)
  model compatible with HuggingFace weights."""
- from typing import Optional, Tuple
+ from typing import Optional, Tuple, Iterable

  import torch
  from torch import nn
  from transformers import PretrainedConfig
+ from vllm.config import CacheConfig
+ from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.activation import SiluAndMul
  from vllm.model_executor.layers.linear import (
      MergedColumnParallelLinear,
      QKVParallelLinear,
      RowParallelLinear,
  )
- from vllm.model_executor.layers.quantization.base_config import (
-     QuantizationConfig)
+ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
  )
- from vllm.distributed import (
-     get_tensor_model_parallel_world_size,
- )
- from sglang.srt.weight_utils import (
-     default_weight_loader,
-     hf_model_weights_iterator,
- )
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.managers.router.model_runner import InputMetadata
+ from sglang.srt.managers.controller.model_runner import InputMetadata


  class StablelmMLP(nn.Module):
      def __init__(
-         self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
      ) -> None:
          super().__init__()
          self.config = config
@@ -48,7 +45,10 @@ class StablelmMLP(nn.Module):
              quant_config=quant_config,
          )
          self.down_proj = RowParallelLinear(
-             config.intermediate_size, config.hidden_size, bias=False, quant_config=quant_config,
+             config.intermediate_size,
+             config.hidden_size,
+             bias=False,
+             quant_config=quant_config,
          )
          self.act_fn = SiluAndMul()

@@ -181,7 +181,9 @@ class StablelmDecoderLayer(nn.Module):

  class StableLMEpochModel(nn.Module):
      def __init__(
-         self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
      ) -> None:
          super().__init__()
          self.embed_tokens = VocabParallelEmbedding(
@@ -224,6 +226,7 @@ class StableLmForCausalLM(nn.Module):
          self,
          config: PretrainedConfig,
          quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
      ) -> None:
          super().__init__()
          self.config = config
@@ -244,13 +247,7 @@ class StableLmForCausalLM(nn.Module):
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )

-     def load_weights(
-         self,
-         model_name_or_path: str,
-         cache_dir: Optional[str] = None,
-         load_format: str = "auto",
-         revision: Optional[str] = None,
-     ):
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
              # (param_name, shard_name, shard_id)
              ("qkv_proj", "q_proj", "q"),
@@ -260,9 +257,7 @@ class StableLmForCausalLM(nn.Module):
              ("gate_up_proj", "up_proj", 1),
          ]
          params_dict = dict(self.named_parameters())
-         for name, loaded_weight in hf_model_weights_iterator(
-             model_name_or_path, cache_dir, load_format, revision
-         ):
+         for name, loaded_weight in weights:
              if "rotary_emb.inv_freq" in name:
                  continue
              if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: