sglang 0.1.16__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (65)
  1. sglang/__init__.py +3 -1
  2. sglang/api.py +3 -3
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +148 -12
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/global_config.py +8 -1
  8. sglang/lang/interpreter.py +114 -67
  9. sglang/lang/ir.py +17 -2
  10. sglang/srt/constrained/fsm_cache.py +3 -0
  11. sglang/srt/flush_cache.py +1 -1
  12. sglang/srt/hf_transformers_utils.py +75 -1
  13. sglang/srt/layers/extend_attention.py +17 -0
  14. sglang/srt/layers/fused_moe.py +485 -0
  15. sglang/srt/layers/logits_processor.py +12 -7
  16. sglang/srt/layers/radix_attention.py +10 -3
  17. sglang/srt/layers/token_attention.py +16 -1
  18. sglang/srt/managers/controller/dp_worker.py +110 -0
  19. sglang/srt/managers/controller/infer_batch.py +619 -0
  20. sglang/srt/managers/controller/manager_multi.py +191 -0
  21. sglang/srt/managers/controller/manager_single.py +97 -0
  22. sglang/srt/managers/controller/model_runner.py +462 -0
  23. sglang/srt/managers/controller/radix_cache.py +267 -0
  24. sglang/srt/managers/controller/schedule_heuristic.py +59 -0
  25. sglang/srt/managers/controller/tp_worker.py +791 -0
  26. sglang/srt/managers/detokenizer_manager.py +45 -45
  27. sglang/srt/managers/io_struct.py +15 -11
  28. sglang/srt/managers/router/infer_batch.py +103 -59
  29. sglang/srt/managers/router/manager.py +1 -1
  30. sglang/srt/managers/router/model_rpc.py +175 -122
  31. sglang/srt/managers/router/model_runner.py +91 -104
  32. sglang/srt/managers/router/radix_cache.py +7 -1
  33. sglang/srt/managers/router/scheduler.py +6 -6
  34. sglang/srt/managers/tokenizer_manager.py +152 -89
  35. sglang/srt/model_config.py +4 -5
  36. sglang/srt/models/commandr.py +10 -13
  37. sglang/srt/models/dbrx.py +9 -15
  38. sglang/srt/models/gemma.py +8 -15
  39. sglang/srt/models/grok.py +671 -0
  40. sglang/srt/models/llama2.py +19 -15
  41. sglang/srt/models/llava.py +84 -20
  42. sglang/srt/models/llavavid.py +11 -20
  43. sglang/srt/models/mixtral.py +248 -118
  44. sglang/srt/models/mixtral_quant.py +373 -0
  45. sglang/srt/models/qwen.py +9 -13
  46. sglang/srt/models/qwen2.py +11 -13
  47. sglang/srt/models/stablelm.py +9 -15
  48. sglang/srt/models/yivl.py +17 -22
  49. sglang/srt/openai_api_adapter.py +140 -95
  50. sglang/srt/openai_protocol.py +10 -1
  51. sglang/srt/server.py +77 -42
  52. sglang/srt/server_args.py +51 -6
  53. sglang/srt/utils.py +124 -66
  54. sglang/test/test_programs.py +44 -0
  55. sglang/test/test_utils.py +32 -1
  56. sglang/utils.py +22 -4
  57. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/METADATA +15 -9
  58. sglang-0.1.17.dist-info/RECORD +81 -0
  59. sglang/srt/backend_config.py +0 -13
  60. sglang/srt/models/dbrx_config.py +0 -281
  61. sglang/srt/weight_utils.py +0 -417
  62. sglang-0.1.16.dist-info/RECORD +0 -72
  63. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
  64. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
  65. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
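
The headline additions in this release are a LiteLLM backend (sglang/backend/litellm.py), a new sglang/srt/managers/controller/ serving stack, the removal of sglang/srt/weight_utils.py in favor of vllm's weight loading, a fused MoE kernel, and new model implementations (grok.py, mixtral_quant.py). As a quick orientation before the per-file diffs, here is a minimal usage sketch for the new backend; it assumes the module exposes a LiteLLM class that plugs in like the existing OpenAI backend, and the model name is only illustrative — neither detail is confirmed by this diff.

# Hypothetical usage sketch for the new litellm backend.
# Assumption: sglang/backend/litellm.py exposes a LiteLLM class that can be
# used as a default backend like sgl.OpenAI; the model name is illustrative.
import sglang as sgl
from sglang.backend.litellm import LiteLLM

@sgl.function
def qa(s, question):
    s += "Q: " + question + "\n"
    s += "A: " + sgl.gen("answer", max_tokens=64)

sgl.set_default_backend(LiteLLM("gpt-3.5-turbo"))
state = qa.run(question="What does sglang 0.1.17 add?")
print(state["answer"])
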
sglang/srt/models/mixtral_quant.py ADDED
@@ -0,0 +1,373 @@
+ # Adapted from
+ # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral_quant.py#L1
+ """Inference-only Mixtral model."""
+ from typing import Iterable, Optional, Tuple
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from transformers import MixtralConfig
+ from vllm.config import CacheConfig
+ from vllm.distributed import (
+     get_tensor_model_parallel_rank,
+     get_tensor_model_parallel_world_size,
+     tensor_model_parallel_all_reduce,
+ )
+ from vllm.model_executor.layers.layernorm import RMSNorm
+ from vllm.model_executor.layers.linear import (
+     QKVParallelLinear,
+     ReplicatedLinear,
+     RowParallelLinear,
+ )
+ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+ from vllm.model_executor.layers.rotary_embedding import get_rope
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.managers.controller.model_runner import InputMetadata
+
+
+ class MixtralMLP(nn.Module):
+     def __init__(
+         self,
+         num_experts: int,
+         hidden_size: int,
+         intermediate_size: int,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.num_experts = num_experts
+         self.ffn_dim = intermediate_size
+         self.hidden_dim = hidden_size
+
+         self.w1 = ReplicatedLinear(
+             self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
+         )
+         self.w2 = ReplicatedLinear(
+             self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config
+         )
+         self.w3 = ReplicatedLinear(
+             self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
+         )
+
+         # TODO: Use vllm's SiluAndMul
+         self.act_fn = nn.SiLU()
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         w1_out, _ = self.w1(hidden_states)
+         w1_out = self.act_fn(w1_out)
+         w3_out, _ = self.w3(hidden_states)
+         current_hidden_states = w1_out * w3_out
+         current_hidden_states, _ = self.w2(current_hidden_states)
+         return current_hidden_states
+
+
+ class MixtralMoE(nn.Module):
+     def __init__(
+         self,
+         config: MixtralConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+     ):
+         super().__init__()
+         self.config = config
+         self.rank = get_tensor_model_parallel_rank()
+         self.tp_size = get_tensor_model_parallel_world_size()
+         self.num_total_experts = config.num_local_experts
+         self.top_k = config.num_experts_per_tok
+         if self.tp_size > self.num_total_experts:
+             raise ValueError(
+                 f"Tensor parallel size {self.tp_size} is greater than "
+                 f"the number of experts {self.num_total_experts}."
+             )
+         # Split experts equally between ranks
+         self.expert_indicies = np.array_split(
+             range(self.num_total_experts), self.tp_size
+         )[self.rank].tolist()
+         if not self.expert_indicies:
+             raise ValueError(f"Rank {self.rank} has no experts assigned to it.")
+
+         self.experts = nn.ModuleList(
+             [
+                 (
+                     MixtralMLP(
+                         self.num_total_experts,
+                         config.hidden_size,
+                         config.intermediate_size,
+                         quant_config=quant_config,
+                     )
+                     if idx in self.expert_indicies
+                     else None
+                 )
+                 for idx in range(self.num_total_experts)
+             ]
+         )
+         self.gate = ReplicatedLinear(
+             config.hidden_size, self.num_total_experts, bias=False, quant_config=None
+         )
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         router_logits, _ = self.gate(hidden_states)
+
+         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+         routing_weights, selected_experts = torch.topk(
+             routing_weights, self.top_k, dim=-1
+         )
+         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+
+         final_hidden_states = None
+         for expert_idx in self.expert_indicies:
+             expert_layer = self.experts[expert_idx]
+             expert_mask = selected_experts == expert_idx
+             expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)
+
+             current_hidden_states = expert_layer(hidden_states).mul_(expert_weights)
+             if final_hidden_states is None:
+                 final_hidden_states = current_hidden_states
+             else:
+                 final_hidden_states.add_(current_hidden_states)
+
+         return tensor_model_parallel_all_reduce(final_hidden_states)
+
+
+ class MixtralAttention(nn.Module):
+     def __init__(
+         self,
+         hidden_size: int,
+         num_heads: int,
+         num_kv_heads: int,
+         layer_id: int = 0,
+         max_position: int = 4096 * 32,
+         rope_theta: float = 10000,
+         quant_config: Optional[QuantizationConfig] = None,
+         sliding_window: Optional[int] = None,
+     ) -> None:
+         super().__init__()
+         self.hidden_size = hidden_size
+         tp_size = get_tensor_model_parallel_world_size()
+         self.total_num_heads = num_heads
+         assert self.total_num_heads % tp_size == 0
+         self.num_heads = self.total_num_heads // tp_size
+         self.total_num_kv_heads = num_kv_heads
+         if self.total_num_kv_heads >= tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+         self.head_dim = hidden_size // self.total_num_heads
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+         self.rope_theta = rope_theta
+         self.sliding_window = sliding_window
+
+         self.qkv_proj = QKVParallelLinear(
+             hidden_size,
+             self.head_dim,
+             self.total_num_heads,
+             self.total_num_kv_heads,
+             bias=False,
+             quant_config=quant_config,
+         )
+         self.o_proj = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             hidden_size,
+             bias=False,
+             quant_config=quant_config,
+         )
+         self.rotary_emb = get_rope(
+             self.head_dim,
+             rotary_dim=self.head_dim,
+             max_position=max_position,
+             base=int(self.rope_theta),
+             is_neox_style=True,
+         )
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             layer_id=layer_id,
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+         q, k = self.rotary_emb(positions, q, k)
+         attn_output = self.attn(q, k, v, input_metadata)
+         output, _ = self.o_proj(attn_output)
+         return output
+
+
+ class MixtralDecoderLayer(nn.Module):
+     def __init__(
+         self,
+         config: MixtralConfig,
+         layer_id: int = 0,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         # Requires transformers > 4.32.0
+         rope_theta = getattr(config, "rope_theta", 10000)
+         self.self_attn = MixtralAttention(
+             hidden_size=self.hidden_size,
+             num_heads=config.num_attention_heads,
+             max_position=config.max_position_embeddings,
+             num_kv_heads=config.num_key_value_heads,
+             layer_id=layer_id,
+             rope_theta=rope_theta,
+             sliding_window=config.sliding_window,
+             quant_config=quant_config,
+         )
+         self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config)
+         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+         residual: Optional[torch.Tensor],
+     ) -> torch.Tensor:
+         # Self Attention
+         if residual is None:
+             residual = hidden_states
+             hidden_states = self.input_layernorm(hidden_states)
+         else:
+             hidden_states, residual = self.input_layernorm(hidden_states, residual)
+         hidden_states = self.self_attn(
+             positions=positions,
+             hidden_states=hidden_states,
+             input_metadata=input_metadata,
+         )
+
+         # Fully Connected
+         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+         hidden_states = self.block_sparse_moe(hidden_states)
+         return hidden_states, residual
+
+
+ class MixtralModel(nn.Module):
+     def __init__(
+         self,
+         config: MixtralConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+
+         self.embed_tokens = VocabParallelEmbedding(
+             config.vocab_size,
+             config.hidden_size,
+         )
+         self.layers = nn.ModuleList(
+             [
+                 MixtralDecoderLayer(config, i, quant_config=quant_config)
+                 for i in range(config.num_hidden_layers)
+             ]
+         )
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         if input_embeds is None:
+             hidden_states = self.embed_tokens(input_ids)
+         else:
+             hidden_states = input_embeds
+         residual = None
+         for i in range(len(self.layers)):
+             layer = self.layers[i]
+             hidden_states, residual = layer(
+                 positions, hidden_states, input_metadata, residual
+             )
+         hidden_states, _ = self.norm(hidden_states, residual)
+         return hidden_states
+
+
+ class QuantMixtralForCausalLM(nn.Module):
+     def __init__(
+         self,
+         config: MixtralConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.quant_config = quant_config
+         self.model = MixtralModel(config, quant_config=quant_config)
+         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+         self.logits_processor = LogitsProcessor(config)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+         return self.logits_processor(
+             input_ids, hidden_states, self.lm_head.weight, input_metadata
+         )
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+         ]
+
+         params_dict = dict(self.named_parameters())
+         for name, loaded_weight in weights:
+             if "rotary_emb.inv_freq" in name:
+                 continue
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 # Skip experts that are not assigned to this worker.
+                 if "block_sparse_moe.experts." in name and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                 weight_loader(param, loaded_weight)
+
+
+ EntryClass = QuantMixtralForCausalLM
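
sglang/srt/models/mixtral_quant.py is the largest new file above: a tensor-parallel Mixtral implementation for quantized checkpoints in which each rank instantiates only its slice of the experts (expert_indicies), applies top-k routing, and combines per-rank partial results with an all-reduce. The standalone PyTorch sketch below illustrates the same routing arithmetic on a single device, with no tensor parallelism, quantization, or vllm dependencies; module sizes are arbitrary and for illustration only.

# Standalone illustration of the top-k MoE routing used above, on one device.
# This is a sketch for clarity, not the shipped implementation.
import torch
import torch.nn.functional as F
from torch import nn

class ToyMoE(nn.Module):
    def __init__(self, hidden=32, ffn=64, num_experts=8, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(hidden, num_experts, bias=False)
        self.experts = nn.ModuleList(
            [nn.Sequential(nn.Linear(hidden, ffn), nn.SiLU(), nn.Linear(ffn, hidden))
             for _ in range(num_experts)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: [tokens, hidden]
        logits = self.gate(x)
        weights = F.softmax(logits, dim=-1, dtype=torch.float)
        weights, selected = torch.topk(weights, self.top_k, dim=-1)
        weights /= weights.sum(dim=-1, keepdim=True)   # renormalize over the top-k
        out = torch.zeros_like(x)
        for idx, expert in enumerate(self.experts):
            mask = selected == idx                      # [tokens, top_k]
            w = (weights * mask).sum(dim=-1, keepdim=True).to(x.dtype)
            out += expert(x) * w                        # zero weight for unrouted tokens
        return out

x = torch.randn(4, 32)
print(ToyMoE()(x).shape)  # torch.Size([4, 32])
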
sglang/srt/models/qwen.py CHANGED
@@ -1,8 +1,11 @@
- from typing import Any, Dict, Optional
+ # Adapted from
+ # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
+ from typing import Any, Dict, Optional, Iterable, Tuple
 
  import torch
  from torch import nn
  from transformers import PretrainedConfig
+ from vllm.config import CacheConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.activation import SiluAndMul
  from vllm.model_executor.layers.layernorm import RMSNorm
@@ -17,11 +20,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
  )
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.managers.router.model_runner import InputMetadata
- from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
+ from sglang.srt.managers.controller.model_runner import InputMetadata
 
 
  class QWenMLP(nn.Module):
@@ -225,6 +228,7 @@ class QWenLMHeadModel(nn.Module):
          self,
          config: PretrainedConfig,
          quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
      ):
          super().__init__()
          self.config = config
@@ -245,22 +249,14 @@ class QWenLMHeadModel(nn.Module):
          )
          return next_tokens
 
-     def load_weights(
-         self,
-         model_name_or_path: str,
-         cache_dir: Optional[str] = None,
-         load_format: str = "auto",
-         revision: Optional[str] = None,
-     ):
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
              # (param_name, shard_name, shard_id)
              ("gate_up_proj", "w2", 0),
              ("gate_up_proj", "w1", 1),
          ]
          params_dict = dict(self.named_parameters())
-         for name, loaded_weight in hf_model_weights_iterator(
-             model_name_or_path, cache_dir, load_format, revision
-         ):
+         for name, loaded_weight in weights:
              if "rotary_emb.inv_freq" in name:
                  continue
              for param_name, weight_name, shard_id in stacked_params_mapping:
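
The qwen.py hunks above show a change that repeats in qwen2.py, stablelm.py, and yivl.py below: load_weights no longer takes a model path and pulls checkpoints itself through the removed sglang.srt.weight_utils helpers; it now consumes an iterator of (name, tensor) pairs supplied by vllm's model loader, and each model also gains a cache_config constructor argument. A minimal sketch of the new loader contract follows; it assumes checkpoint names match named_parameters(), whereas the real models additionally remap stacked weights such as qkv_proj and gate_up_proj.

# Minimal sketch of the new load_weights contract: the caller streams
# (name, tensor) pairs and the model copies them into its parameters.
# Assumption: checkpoint names line up with named_parameters().
from typing import Iterable, Tuple
import torch
from torch import nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8, bias=False)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if name not in params_dict:   # e.g. rotary caches, unused biases
                continue
            params_dict[name].data.copy_(loaded_weight)

m = TinyModel()
m.load_weights([("proj.weight", torch.eye(8))])
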
sglang/srt/models/qwen2.py CHANGED
@@ -1,10 +1,11 @@
  # Adapted from llama2.py
  # Modify details for the adaptation of Qwen2 model.
  """Inference-only Qwen2 model compatible with HuggingFace weights."""
- from typing import Any, Dict, Optional, Tuple
+ from typing import Any, Dict, Optional, Tuple, Iterable
 
  import torch
  from torch import nn
+ from vllm.config import CacheConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.activation import SiluAndMul
  from vllm.model_executor.layers.layernorm import RMSNorm
@@ -19,11 +20,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
  )
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.managers.router.model_runner import InputMetadata
- from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
+ from sglang.srt.managers.controller.model_runner import InputMetadata
 
  Qwen2Config = None
 
@@ -251,6 +252,7 @@ class Qwen2ForCausalLM(nn.Module):
          self,
          config: Qwen2Config,
          quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
      ) -> None:
          super().__init__()
          self.config = config
@@ -271,13 +273,7 @@ class Qwen2ForCausalLM(nn.Module):
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
 
-     def load_weights(
-         self,
-         model_name_or_path: str,
-         cache_dir: Optional[str] = None,
-         load_format: str = "auto",
-         revision: Optional[str] = None,
-     ):
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
              # (param_name, shard_name, shard_id)
              ("qkv_proj", "q_proj", "q"),
@@ -287,9 +283,7 @@ class Qwen2ForCausalLM(nn.Module):
              ("gate_up_proj", "up_proj", 1),
          ]
          params_dict = dict(self.named_parameters())
-         for name, loaded_weight in hf_model_weights_iterator(
-             model_name_or_path, cache_dir, load_format, revision
-         ):
+         for name, loaded_weight in weights:
              if "rotary_emb.inv_freq" in name or "projector" in name:
                  continue
              if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
@@ -303,6 +297,8 @@ class Qwen2ForCausalLM(nn.Module):
                  # Skip loading extra bias for GPTQ models.
                  if name.endswith(".bias") and name not in params_dict:
                      continue
+                 if name.startswith("model.vision_tower") and name not in params_dict:
+                     continue
                  param = params_dict[name]
                  weight_loader = param.weight_loader
                  weight_loader(param, loaded_weight, shard_id)
@@ -311,6 +307,8 @@ class Qwen2ForCausalLM(nn.Module):
                  # Skip loading extra bias for GPTQ models.
                  if name.endswith(".bias") and name not in params_dict:
                      continue
+                 if name.startswith("model.vision_tower") and name not in params_dict:
+                     continue
                  param = params_dict[name]
                  weight_loader = getattr(param, "weight_loader", default_weight_loader)
                  weight_loader(param, loaded_weight)
sglang/srt/models/stablelm.py CHANGED
@@ -1,12 +1,13 @@
- # This code is based on:
- # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/stablelm.py
+ # Adapted from:
+ # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/stablelm.py#L1
  """Inference-only StableLM-2 (https://huggingface.co/stabilityai/stablelm-2-1_6b)
  model compatible with HuggingFace weights."""
- from typing import Optional, Tuple
+ from typing import Optional, Tuple, Iterable
 
  import torch
  from torch import nn
  from transformers import PretrainedConfig
+ from vllm.config import CacheConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.activation import SiluAndMul
  from vllm.model_executor.layers.linear import (
@@ -20,11 +21,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
  )
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.managers.router.model_runner import InputMetadata
- from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
+ from sglang.srt.managers.controller.model_runner import InputMetadata
 
 
  class StablelmMLP(nn.Module):
@@ -225,6 +226,7 @@ class StableLmForCausalLM(nn.Module):
          self,
          config: PretrainedConfig,
          quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
      ) -> None:
          super().__init__()
          self.config = config
@@ -245,13 +247,7 @@ class StableLmForCausalLM(nn.Module):
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
 
-     def load_weights(
-         self,
-         model_name_or_path: str,
-         cache_dir: Optional[str] = None,
-         load_format: str = "auto",
-         revision: Optional[str] = None,
-     ):
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
              # (param_name, shard_name, shard_id)
              ("qkv_proj", "q_proj", "q"),
@@ -261,9 +257,7 @@ class StableLmForCausalLM(nn.Module):
              ("gate_up_proj", "up_proj", 1),
          ]
          params_dict = dict(self.named_parameters())
-         for name, loaded_weight in hf_model_weights_iterator(
-             model_name_or_path, cache_dir, load_format, revision
-         ):
+         for name, loaded_weight in weights:
              if "rotary_emb.inv_freq" in name:
                  continue
              if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
sglang/srt/models/yivl.py CHANGED
@@ -1,40 +1,38 @@
  """Inference-only Yi-VL model."""
 
- import os
- from typing import List, Optional
+ from typing import Tuple, Iterable, Optional
 
  import torch
  import torch.nn as nn
  from transformers import CLIPVisionModel, LlavaConfig
+ from vllm.config import CacheConfig
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.models.llava import (
      LlavaLlamaForCausalLM,
-     clip_vision_embed_forward,
      monkey_path_clip_vision_embed_forward,
  )
- from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
 
 
  class YiVLForCausalLM(LlavaLlamaForCausalLM):
-     def __init__(self, *args, **kwargs):
-         self.config = kwargs["config"]
-         super().__init__(self.config)
+     def __init__(
+         self,
+         config: LlavaConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
+     ) -> None:
+         super().__init__(config, quant_config, cache_config)
 
          self.multi_modal_projector = YiVLMultiModalProjector(self.config)
          self.vision_tower_subfolder = self.config.mm_vision_tower.replace(
              "./", ""
          ) # Everything after "./"
 
-     def load_weights(
-         self,
-         model_name_or_path: str,
-         cache_dir: Optional[str] = None,
-         load_format: str = "auto",
-         revision: Optional[str] = None,
-     ):
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          # We have to use the subfolder of the main model directory (e.g. 01-ai/Yi-VL-6B)
          self.vision_tower = CLIPVisionModel.from_pretrained(
-             model_name_or_path,
+             self.config._name_or_path,
              torch_dtype=torch.float16,
              subfolder=self.vision_tower_subfolder,
          ).cuda()
@@ -68,9 +66,8 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
              "model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
          }
          params_dict = dict(self.named_parameters())
-         for name, loaded_weight in hf_model_weights_iterator(
-             model_name_or_path, cache_dir, load_format, revision
-         ):
+         weights = list(weights)
+         for name, loaded_weight in weights:
              if "projector" in name or "vision_tower" in name:
                  for weight_name, param_name in projector_weights.items():
                      if weight_name in name:
@@ -80,9 +77,7 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
                  weight_loader(param, loaded_weight)
 
          # load language model
-         self.language_model.load_weights(
-             model_name_or_path, cache_dir, load_format, revision
-         )
+         self.language_model.load_weights(weights)
 
          monkey_path_clip_vision_embed_forward()
 
@@ -103,7 +98,7 @@ class YiVLMultiModalProjector(nn.Module):
 
      def forward(self, image_features):
          hidden_states = self.linear_1(image_features)
-         hidden_state = self.ln_1(hidden_states)
+         hidden_states = self.ln_1(hidden_states)
          hidden_states = self.act(hidden_states)
          hidden_states = self.linear_2(hidden_states)
          hidden_states = self.ln_2(hidden_states)
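
The final yivl.py hunk fixes a genuine bug: YiVLMultiModalProjector.forward assigned the ln_1 output to hidden_state (singular), so the first LayerNorm was silently skipped. The toy module below traces the corrected data flow; the layer sizes and GELU activation are assumptions for illustration, not read from the diff.

# Corrected projector data flow: linear -> layernorm -> activation -> linear -> layernorm.
# Sizes and the GELU activation are illustrative; the real module reads its
# dimensions from the model config.
import torch
from torch import nn

class ToyProjector(nn.Module):
    def __init__(self, vision_dim=16, text_dim=32):
        super().__init__()
        self.linear_1 = nn.Linear(vision_dim, text_dim)
        self.ln_1 = nn.LayerNorm(text_dim)
        self.act = nn.GELU()
        self.linear_2 = nn.Linear(text_dim, text_dim)
        self.ln_2 = nn.LayerNorm(text_dim)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        hidden_states = self.linear_1(image_features)
        hidden_states = self.ln_1(hidden_states)  # previously dropped due to the typo
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return self.ln_2(hidden_states)

print(ToyProjector()(torch.randn(2, 16)).shape)  # torch.Size([2, 32])
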