sglang 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. sglang/__init__.py +3 -1
  2. sglang/api.py +7 -7
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +158 -11
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/bench_latency.py +299 -0
  8. sglang/global_config.py +12 -2
  9. sglang/lang/compiler.py +2 -2
  10. sglang/lang/interpreter.py +114 -67
  11. sglang/lang/ir.py +28 -3
  12. sglang/launch_server.py +4 -1
  13. sglang/launch_server_llavavid.py +2 -1
  14. sglang/srt/constrained/__init__.py +13 -6
  15. sglang/srt/constrained/fsm_cache.py +8 -2
  16. sglang/srt/constrained/jump_forward.py +113 -25
  17. sglang/srt/conversation.py +2 -0
  18. sglang/srt/flush_cache.py +3 -1
  19. sglang/srt/hf_transformers_utils.py +130 -1
  20. sglang/srt/layers/extend_attention.py +17 -0
  21. sglang/srt/layers/fused_moe.py +582 -0
  22. sglang/srt/layers/logits_processor.py +65 -32
  23. sglang/srt/layers/radix_attention.py +41 -7
  24. sglang/srt/layers/token_attention.py +16 -1
  25. sglang/srt/managers/controller/dp_worker.py +113 -0
  26. sglang/srt/managers/{router → controller}/infer_batch.py +242 -100
  27. sglang/srt/managers/controller/manager_multi.py +191 -0
  28. sglang/srt/managers/{router/manager.py → controller/manager_single.py} +34 -14
  29. sglang/srt/managers/{router → controller}/model_runner.py +262 -158
  30. sglang/srt/managers/{router → controller}/radix_cache.py +11 -1
  31. sglang/srt/managers/{router/scheduler.py → controller/schedule_heuristic.py} +9 -7
  32. sglang/srt/managers/{router/model_rpc.py → controller/tp_worker.py} +298 -267
  33. sglang/srt/managers/detokenizer_manager.py +42 -46
  34. sglang/srt/managers/io_struct.py +22 -12
  35. sglang/srt/managers/tokenizer_manager.py +151 -87
  36. sglang/srt/model_config.py +83 -5
  37. sglang/srt/models/chatglm.py +399 -0
  38. sglang/srt/models/commandr.py +10 -13
  39. sglang/srt/models/dbrx.py +9 -15
  40. sglang/srt/models/gemma.py +12 -15
  41. sglang/srt/models/grok.py +738 -0
  42. sglang/srt/models/llama2.py +26 -15
  43. sglang/srt/models/llama_classification.py +104 -0
  44. sglang/srt/models/llava.py +86 -19
  45. sglang/srt/models/llavavid.py +11 -20
  46. sglang/srt/models/mixtral.py +282 -103
  47. sglang/srt/models/mixtral_quant.py +372 -0
  48. sglang/srt/models/qwen.py +9 -13
  49. sglang/srt/models/qwen2.py +11 -13
  50. sglang/srt/models/stablelm.py +9 -15
  51. sglang/srt/models/yivl.py +17 -22
  52. sglang/srt/openai_api_adapter.py +150 -95
  53. sglang/srt/openai_protocol.py +11 -2
  54. sglang/srt/server.py +124 -48
  55. sglang/srt/server_args.py +128 -48
  56. sglang/srt/utils.py +234 -67
  57. sglang/test/test_programs.py +65 -3
  58. sglang/test/test_utils.py +32 -1
  59. sglang/utils.py +23 -4
  60. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/METADATA +40 -27
  61. sglang-0.1.18.dist-info/RECORD +78 -0
  62. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
  63. sglang/srt/backend_config.py +0 -13
  64. sglang/srt/models/dbrx_config.py +0 -281
  65. sglang/srt/weight_utils.py +0 -417
  66. sglang-0.1.16.dist-info/RECORD +0 -72
  67. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
  68. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
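Note on the renames above: the `sglang/srt/managers/router` package becomes `sglang/srt/managers/controller` in 0.1.18, and every model file below updates its imports accordingly. A minimal compatibility sketch (the try/except pattern is an assumption for downstream code, not something the package itself ships):

# Works on both 0.1.16 and 0.1.18; prefer the new controller path.
try:
    from sglang.srt.managers.controller.model_runner import InputMetadata  # 0.1.18+
except ImportError:
    from sglang.srt.managers.router.model_runner import InputMetadata  # 0.1.16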
sglang/srt/models/chatglm.py ADDED
@@ -0,0 +1,399 @@
+ # coding=utf-8
+ # Adapted from
+ # https://github.com/THUDM/ChatGLM2-6B
+ """Inference-only ChatGLM model compatible with THUDM weights."""
+ from typing import Iterable, List, Optional, Tuple
+
+ import torch
+ from torch import nn
+ from torch.nn import LayerNorm
+ from vllm.config import CacheConfig
+ from vllm.distributed import get_tensor_model_parallel_world_size
+ from vllm.model_executor.layers.activation import SiluAndMul
+ from vllm.model_executor.layers.layernorm import RMSNorm
+ from vllm.model_executor.layers.linear import (
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
+ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+ from vllm.model_executor.layers.rotary_embedding import get_rope
+ from vllm.model_executor.layers.sampler import Sampler
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
+ from vllm.sequence import SamplerOutput
+ from vllm.transformers_utils.configs import ChatGLMConfig
+
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.managers.controller.model_runner import InputMetadata
+
+ LoraConfig = None
+
+
+ class GLMAttention(nn.Module):
+     def __init__(
+         self,
+         config,
+         layer_id: int = 0,
+         cache_config: Optional[CacheConfig] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+     ):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         tp_size = get_tensor_model_parallel_world_size()
+         self.total_num_heads = config.num_attention_heads
+         assert self.total_num_heads % tp_size == 0
+         self.num_heads = self.total_num_heads // tp_size
+         self.multi_query_attention = config.multi_query_attention
+         self.total_num_kv_heads = (
+             config.multi_query_group_num
+             if config.multi_query_attention
+             else config.num_attention_heads
+         )
+         if self.total_num_kv_heads >= tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+         self.head_dim = config.hidden_size // self.total_num_heads
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+
+         self.query_key_value = QKVParallelLinear(
+             self.hidden_size,
+             self.head_dim,
+             self.total_num_heads,
+             self.total_num_kv_heads,
+             bias=config.add_bias_linear or config.add_qkv_bias,
+             quant_config=quant_config,
+         )
+         self.dense = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             config.hidden_size,
+             bias=config.add_bias_linear,
+             quant_config=quant_config,
+         )
+
+         # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
+         rope_ratio = getattr(config, "rope_ratio", 1.0)
+         max_positions = getattr(config, "seq_length", 8192)
+         self.rotary_emb = get_rope(
+             self.head_dim,
+             rotary_dim=self.head_dim // 2,
+             max_position=max_positions,
+             base=10000 * rope_ratio,
+             is_neox_style=False,
+         )
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             layer_id=layer_id,
+         )
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         position_ids: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         qkv, _ = self.query_key_value(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+         q, k = self.rotary_emb(position_ids, q, k)
+         context_layer = self.attn(
+             q,
+             k,
+             v,
+             input_metadata,
+         )
+         attn_output, _ = self.dense(context_layer)
+         return attn_output
+
+
+ class GLMMLP(nn.Module):
+     """MLP.
+
+     MLP will take the input with h hidden state, project it to 4*h
+     hidden dimension, perform nonlinear transformation, and project the
+     state back into h hidden dimension.
+     """
+
+     def __init__(
+         self,
+         config,
+         quant_config: Optional[QuantizationConfig] = None,
+     ):
+         super().__init__()
+
+         self.add_bias = config.add_bias_linear
+
+         # Project to 4h.
+         self.dense_h_to_4h = MergedColumnParallelLinear(
+             config.hidden_size,
+             [config.ffn_hidden_size] * 2,
+             bias=config.add_bias_linear,
+             quant_config=quant_config,
+         )
+
+         self.activation_func = SiluAndMul()
+
+         # Project back to h.
+         self.dense_4h_to_h = RowParallelLinear(
+             config.ffn_hidden_size,
+             config.hidden_size,
+             bias=config.add_bias_linear,
+             quant_config=quant_config,
+         )
+
+     def forward(self, hidden_states):
+         # [s, b, 4hp]
+         intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
+         intermediate_parallel = self.activation_func(intermediate_parallel)
+         # [s, b, h]
+         output, _ = self.dense_4h_to_h(intermediate_parallel)
+         return output
+
+
+ class GLMBlock(nn.Module):
+     """A single transformer layer.
+
+     Transformer layer takes input with size [s, b, h] and returns an
+     output of the same size.
+     """
+
+     def __init__(
+         self,
+         config,
+         layer_id: int,
+         cache_config: Optional[CacheConfig] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+     ):
+         super().__init__()
+         self.apply_residual_connection_post_layernorm = (
+             config.apply_residual_connection_post_layernorm
+         )
+
+         self.fp32_residual_connection = config.fp32_residual_connection
+
+         layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
+         # Layernorm on the input data.
+         self.input_layernorm = layer_norm_func(
+             config.hidden_size, eps=config.layernorm_epsilon
+         )
+
+         # Self attention.
+         self.self_attention = GLMAttention(config, layer_id, cache_config, quant_config)
+         self.hidden_dropout = config.hidden_dropout
+
+         # Layernorm on the attention output
+         self.post_attention_layernorm = layer_norm_func(
+             config.hidden_size, eps=config.layernorm_epsilon
+         )
+
+         # MLP
+         self.mlp = GLMMLP(config, quant_config)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         position_ids: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         # hidden_states: [num_tokens, h]
+         # Layer norm at the beginning of the transformer layer.
+         layernorm_output = self.input_layernorm(hidden_states)
+         # Self attention.
+         attention_output = self.self_attention(
+             hidden_states=layernorm_output,
+             position_ids=position_ids,
+             input_metadata=input_metadata,
+         )
+
+         # Residual connection.
+         if self.apply_residual_connection_post_layernorm:
+             residual = layernorm_output
+         else:
+             residual = hidden_states
+
+         layernorm_input = residual + attention_output
+
+         # Layer norm post the self attention.
+         layernorm_output = self.post_attention_layernorm(layernorm_input)
+
+         # Second residual connection.
+         if self.apply_residual_connection_post_layernorm:
+             residual = layernorm_output
+         else:
+             residual = layernorm_input
+
+         output = self.mlp(layernorm_output) + residual
+
+         return output
+
+
+ class GLMTransformer(nn.Module):
+     """Transformer class."""
+
+     def __init__(
+         self,
+         config,
+         cache_config: Optional[CacheConfig] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+     ):
+         super().__init__()
+         self.post_layer_norm = config.post_layer_norm
+
+         # Number of layers.
+         self.num_layers = config.num_layers
+
+         # Transformer layers.
+         self.layers = nn.ModuleList(
+             [
+                 GLMBlock(config, i, cache_config, quant_config)
+                 for i in range(self.num_layers)
+             ]
+         )
+
+         if self.post_layer_norm:
+             layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
+             # Final layer norm before output.
+             self.final_layernorm = layer_norm_func(
+                 config.hidden_size, eps=config.layernorm_epsilon
+             )
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         position_ids: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         for i in range(self.num_layers):
+             layer = self.layers[i]
+             hidden_states = layer(
+                 hidden_states=hidden_states,
+                 position_ids=position_ids,
+                 input_metadata=input_metadata,
+             )
+         # Final layer norm.
+         if self.post_layer_norm:
+             hidden_states = self.final_layernorm(hidden_states)
+
+         return hidden_states
+
+
+ class ChatGLMModel(nn.Module):
+     def __init__(
+         self,
+         config,
+         cache_config: Optional[CacheConfig] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+     ):
+         super().__init__()
+
+         self.embedding = VocabParallelEmbedding(
+             config.padded_vocab_size, config.hidden_size
+         )
+
+         self.num_layers = config.num_layers
+         self.multi_query_group_num = config.multi_query_group_num
+         self.kv_channels = config.kv_channels
+         self.encoder = GLMTransformer(config, cache_config, quant_config)
+
+         self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         position_ids: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         inputs_embeds = self.embedding(input_ids)
+
+         # Run encoder.
+         hidden_states = self.encoder(
+             hidden_states=inputs_embeds,
+             position_ids=position_ids,
+             input_metadata=input_metadata,
+         )
+         return hidden_states
+
+
+ class ChatGLMForCausalLM(nn.Module):
+     packed_modules_mapping = {
+         "query_key_value": ["query_key_value"],
+         "dense_h_to_4h": ["dense_h_to_4h"],
+     }
+     # LoRA specific attributes
+     supported_lora_modules = [
+         "query_key_value",
+         "dense",
+         "dense_h_to_4h",
+         "dense_4h_to_h",
+     ]
+     embedding_modules = {}
+     embedding_padding_modules = []
+
+     def __init__(
+         self,
+         config: ChatGLMConfig,
+         cache_config: Optional[CacheConfig] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+         lora_config: Optional[LoraConfig] = None,
+     ):
+         super().__init__()
+         self.config: ChatGLMConfig = config
+         self.quant_config = quant_config
+         self.max_position_embeddings = getattr(config, "max_sequence_length", 8192)
+         self.transformer = ChatGLMModel(config, cache_config, quant_config)
+         self.lm_head = self.transformer.output_layer
+         self.logits_processor = LogitsProcessor(config)
+         self.sampler = Sampler()
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         hidden_states = self.transformer(input_ids, positions, input_metadata)
+         return self.logits_processor(
+             input_ids, hidden_states, self.lm_head.weight, input_metadata
+         )
+
+     def sample(
+         self,
+         logits: torch.Tensor,
+         sampling_metadata: SamplingMetadata,
+     ) -> Optional[SamplerOutput]:
+         next_tokens = self.sampler(logits, sampling_metadata)
+         return next_tokens
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         params_dict = dict(self.named_parameters(remove_duplicate=False))
+         for name, loaded_weight in weights:
+             if "rotary_pos_emb.inv_freq" in name:
+                 continue
+             if "word_embeddings" in name:
+                 name = name.replace(".word_embeddings", "")
+             # Skip loading extra bias for GPTQ models.
+             if name.endswith(".bias") and name not in params_dict:
+                 continue
+             param = params_dict[name]
+             weight_loader = getattr(param, "weight_loader", default_weight_loader)
+             weight_loader(param, loaded_weight)
+
+
+ EntryClass = ChatGLMForCausalLM
+ # compat: glm model.config class == ChatGLMModel
+ EntryClassRemapping = [("ChatGLMModel", ChatGLMForCausalLM)]
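The KV-head handling in GLMAttention above covers both the grouped-query case (more KV heads than tensor-parallel ranks) and the replicated case (fewer). A standalone sketch of that arithmetic, with illustrative head counts that are not taken from the diff:

# Sketch of the KV-head partitioning rule used in GLMAttention above.
def partition_kv_heads(total_num_kv_heads: int, tp_size: int) -> int:
    if total_num_kv_heads >= tp_size:
        # Partition KV heads across tensor-parallel ranks.
        assert total_num_kv_heads % tp_size == 0
    else:
        # Fewer KV heads than ranks: each rank keeps a replicated copy.
        assert tp_size % total_num_kv_heads == 0
    return max(1, total_num_kv_heads // tp_size)

print(partition_kv_heads(total_num_kv_heads=2, tp_size=8))   # 1 (replicated)
print(partition_kv_heads(total_num_kv_heads=32, tp_size=8))  # 4 (partitioned)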
sglang/srt/models/commandr.py CHANGED
@@ -18,15 +18,19 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
 
+ # Adapted from
+ # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/commandr.py#L1
+
  # This file is based on the LLama model definition file in transformers
  """PyTorch Cohere model."""
- from typing import Optional, Tuple
+ from typing import Iterable, Optional, Tuple
 
  import torch
  import torch.utils.checkpoint
  from torch import nn
  from torch.nn.parameter import Parameter
  from transformers import PretrainedConfig
+ from vllm.config import CacheConfig
  from vllm.distributed import (
      get_tensor_model_parallel_rank,
      get_tensor_model_parallel_world_size,
@@ -40,12 +44,12 @@ from vllm.model_executor.layers.linear import (
  from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
  from vllm.model_executor.utils import set_weight_attrs
 
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.managers.router.model_runner import InputMetadata
- from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
+ from sglang.srt.managers.controller.model_runner import InputMetadata
 
 
  @torch.compile
@@ -301,6 +305,7 @@ class CohereForCausalLM(nn.Module):
          self,
          config: PretrainedConfig,
          quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
      ) -> None:
          super().__init__()
          self.config = config
@@ -324,13 +329,7 @@ class CohereForCausalLM(nn.Module):
              input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
          )
 
-     def load_weights(
-         self,
-         model_name_or_path: str,
-         cache_dir: Optional[str] = None,
-         load_format: str = "auto",
-         revision: Optional[str] = None,
-     ):
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
              # (param_name, shard_name, shard_id)
              ("qkv_proj", "q_proj", "q"),
@@ -341,9 +340,7 @@ class CohereForCausalLM(nn.Module):
          ]
          params_dict = dict(self.named_parameters())
          loaded_params = set()
-         for name, loaded_weight in hf_model_weights_iterator(
-             model_name_or_path, cache_dir, load_format, revision
-         ):
+         for name, loaded_weight in weights:
              for param_name, shard_name, shard_id in stacked_params_mapping:
                  if shard_name not in name:
                      continue
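Across these model files, `load_weights` changes from taking a model path (and iterating the checkpoint itself via `hf_model_weights_iterator`) to accepting an already-materialized iterable of `(name, tensor)` pairs, matching the newer vllm loader. A minimal sketch of the new calling convention; `TinyModel` and its single parameter are made-up stand-ins, not classes from the package:

from typing import Iterable, Tuple
import torch

class TinyModel(torch.nn.Module):
    # Stand-in module used only to show the interface shape; the real classes
    # are CohereForCausalLM, DbrxForCausalLM, GemmaForCausalLM, etc.
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4, bias=False)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            params_dict[name].data.copy_(loaded_weight)

model = TinyModel()
# 0.1.16 passed a model path and iterated weights internally;
# 0.1.18 expects (name, tensor) pairs instead.
model.load_weights([("proj.weight", torch.zeros(4, 4))])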
sglang/srt/models/dbrx.py CHANGED
@@ -1,10 +1,11 @@
  # Adapted from:
- # https://github.com/vllm-project/vllm/blob/14ccd94c89d0ffd9da283545d93ab1dfea5da340/vllm/model_executor/models/dbrx.py
+ # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/dbrx.py#L1
  # coding=utf-8
- from typing import Optional
+ from typing import Iterable, Optional, Tuple
 
  import torch
  import torch.nn as nn
+ from vllm.config import CacheConfig
  from vllm.distributed import (
      get_tensor_model_parallel_rank,
      get_tensor_model_parallel_world_size,
@@ -23,13 +24,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
  )
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
  from vllm.model_executor.utils import set_weight_attrs
+ from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.managers.router.model_runner import InputMetadata
- from sglang.srt.models.dbrx_config import DbrxConfig
- from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
+ from sglang.srt.managers.controller.model_runner import InputMetadata
 
 
  class DbrxRouter(nn.Module):
@@ -352,6 +353,7 @@ class DbrxForCausalLM(nn.Module):
          self,
          config: DbrxConfig,
          quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
      ):
          super().__init__()
          self.config = config
@@ -377,13 +379,7 @@ class DbrxForCausalLM(nn.Module):
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
 
-     def load_weights(
-         self,
-         model_name_or_path: str,
-         cache_dir: Optional[str] = None,
-         load_format: str = "auto",
-         revision: Optional[str] = None,
-     ):
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          expert_params_mapping = [
              (
                  "ws" if weight_name in ["w1", "v1"] else "w2s",
@@ -392,9 +388,7 @@ class DbrxForCausalLM(nn.Module):
              for weight_name in ["w1", "v1", "w2"]
          ]
          params_dict = dict(self.named_parameters(remove_duplicate=False))
-         for name, loaded_weight in hf_model_weights_iterator(
-             model_name_or_path, cache_dir, load_format, revision
-         ):
+         for name, loaded_weight in weights:
              for param_name, weight_name in expert_params_mapping:
                  if weight_name not in name:
                      continue
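For reference, the visible part of `expert_params_mapping` above routes the `w1`/`v1` expert weights to the fused `ws` parameter and `w2` to `w2s`; the checkpoint-side name pattern in each tuple falls outside the hunk shown, so it is not reproduced here. A hedged sketch of just that routing rule:

# Sketch of the visible routing rule only (second tuple element omitted,
# since it is elided by the hunk boundary above).
for weight_name in ["w1", "v1", "w2"]:
    fused_param = "ws" if weight_name in ["w1", "v1"] else "w2s"
    print(weight_name, "->", fused_param)
# w1 -> ws, v1 -> ws, w2 -> w2s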
sglang/srt/models/gemma.py CHANGED
@@ -1,12 +1,12 @@
  # Adapted from:
- # https://github.com/vllm-project/vllm/blob/d65fac2738f0287a41955b45df76a2d5a919bff6/vllm/model_executor/models/gemma.py
+ # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/gemma.py#L1
  """Inference-only Gemma model compatible with HuggingFace weights."""
- from typing import Optional, Tuple
+ from typing import Iterable, Optional, Tuple
 
  import torch
  from torch import nn
  from transformers import PretrainedConfig
- from vllm.config import LoRAConfig
+ from vllm.config import CacheConfig, LoRAConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.activation import GeluAndMul
  from vllm.model_executor.layers.layernorm import RMSNorm
@@ -18,11 +18,11 @@ from vllm.model_executor.layers.linear import (
  from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.managers.router.model_runner import InputMetadata
- from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
+ from sglang.srt.managers.controller.model_runner import InputMetadata
 
 
  class GemmaMLP(nn.Module):
@@ -264,6 +264,7 @@ class GemmaForCausalLM(nn.Module):
          config: PretrainedConfig,
          quant_config: Optional[QuantizationConfig] = None,
          lora_config: Optional[LoRAConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
      ) -> None:
          del lora_config  # Unused.
          super().__init__()
@@ -285,13 +286,7 @@ class GemmaForCausalLM(nn.Module):
              input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
          )
 
-     def load_weights(
-         self,
-         model_name_or_path: str,
-         cache_dir: Optional[str] = None,
-         load_format: str = "auto",
-         revision: Optional[str] = None,
-     ):
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
              # (param_name, shard_name, shard_id)
              ("qkv_proj", "q_proj", "q"),
@@ -302,9 +297,7 @@ class GemmaForCausalLM(nn.Module):
          ]
          params_dict = dict(self.named_parameters())
          loaded_params = set()
-         for name, loaded_weight in hf_model_weights_iterator(
-             model_name_or_path, cache_dir, load_format, revision
-         ):
+         for name, loaded_weight in weights:
              for param_name, shard_name, shard_id in stacked_params_mapping:
                  if shard_name not in name:
                      continue
@@ -317,6 +310,10 @@ class GemmaForCausalLM(nn.Module):
                  weight_loader(param, loaded_weight, shard_id)
                  break
              else:
+                 # lm_head is not used in vllm as it is tied with embed_token.
+                 # To prevent errors, skip loading lm_head.weight.
+                 if "lm_head.weight" in name:
+                     continue
                  # Skip loading extra bias for GPTQ models.
                  if name.endswith(".bias") and name not in params_dict:
                      continue
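The new branch in the Gemma loader skips `lm_head.weight` because the output head is tied to the input embedding, so only `model.embed_tokens.weight` needs to be loaded. A minimal sketch of that skip; the checkpoint entries below are illustrative, not taken from a real Gemma checkpoint:

import torch

# Illustrative checkpoint entries; only the tied embedding needs loading.
checkpoint = [
    ("model.embed_tokens.weight", torch.zeros(8, 4)),
    ("lm_head.weight", torch.zeros(8, 4)),  # tied copy, safe to skip
]

loaded = []
for name, tensor in checkpoint:
    if "lm_head.weight" in name:
        continue  # lm_head is tied to embed_tokens, so skip it (as in the diff)
    loaded.append(name)

print(loaded)  # ['model.embed_tokens.weight']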