sglang 0.1.21__py3-none-any.whl → 0.1.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. sglang/__init__.py +8 -8
  2. sglang/api.py +1 -1
  3. sglang/backend/vertexai.py +5 -4
  4. sglang/bench.py +627 -0
  5. sglang/bench_latency.py +22 -19
  6. sglang/bench_serving.py +758 -0
  7. sglang/check_env.py +171 -0
  8. sglang/lang/backend/__init__.py +0 -0
  9. sglang/lang/backend/anthropic.py +77 -0
  10. sglang/lang/backend/base_backend.py +80 -0
  11. sglang/lang/backend/litellm.py +90 -0
  12. sglang/lang/backend/openai.py +438 -0
  13. sglang/lang/backend/runtime_endpoint.py +283 -0
  14. sglang/lang/backend/vertexai.py +149 -0
  15. sglang/lang/tracer.py +1 -1
  16. sglang/launch_server.py +1 -1
  17. sglang/launch_server_llavavid.py +1 -4
  18. sglang/srt/conversation.py +1 -1
  19. sglang/srt/layers/context_flashattention_nopad.py +0 -29
  20. sglang/srt/layers/extend_attention.py +0 -39
  21. sglang/srt/layers/linear.py +869 -0
  22. sglang/srt/layers/quantization/__init__.py +49 -0
  23. sglang/srt/layers/quantization/fp8.py +662 -0
  24. sglang/srt/layers/radix_attention.py +31 -5
  25. sglang/srt/layers/token_attention.py +1 -51
  26. sglang/srt/managers/controller/cuda_graph_runner.py +14 -12
  27. sglang/srt/managers/controller/infer_batch.py +47 -49
  28. sglang/srt/managers/controller/manager_multi.py +107 -100
  29. sglang/srt/managers/controller/manager_single.py +76 -96
  30. sglang/srt/managers/controller/model_runner.py +35 -23
  31. sglang/srt/managers/controller/tp_worker.py +127 -138
  32. sglang/srt/managers/detokenizer_manager.py +49 -5
  33. sglang/srt/managers/io_struct.py +36 -17
  34. sglang/srt/managers/tokenizer_manager.py +228 -125
  35. sglang/srt/memory_pool.py +19 -6
  36. sglang/srt/model_loader/model_loader.py +277 -0
  37. sglang/srt/model_loader/utils.py +260 -0
  38. sglang/srt/models/chatglm.py +1 -0
  39. sglang/srt/models/dbrx.py +1 -0
  40. sglang/srt/models/grok.py +1 -0
  41. sglang/srt/models/internlm2.py +317 -0
  42. sglang/srt/models/llama2.py +65 -16
  43. sglang/srt/models/llama_classification.py +1 -0
  44. sglang/srt/models/llava.py +1 -0
  45. sglang/srt/models/llavavid.py +1 -0
  46. sglang/srt/models/minicpm.py +1 -0
  47. sglang/srt/models/mixtral.py +1 -0
  48. sglang/srt/models/mixtral_quant.py +1 -0
  49. sglang/srt/models/qwen.py +1 -0
  50. sglang/srt/models/qwen2.py +6 -0
  51. sglang/srt/models/qwen2_moe.py +7 -4
  52. sglang/srt/models/stablelm.py +1 -0
  53. sglang/srt/openai_api/adapter.py +432 -0
  54. sglang/srt/openai_api/api_adapter.py +432 -0
  55. sglang/srt/openai_api/openai_api_adapter.py +431 -0
  56. sglang/srt/openai_api/openai_protocol.py +207 -0
  57. sglang/srt/openai_api/protocol.py +208 -0
  58. sglang/srt/openai_protocol.py +17 -0
  59. sglang/srt/sampling_params.py +2 -0
  60. sglang/srt/server.py +113 -84
  61. sglang/srt/server_args.py +23 -15
  62. sglang/srt/utils.py +16 -117
  63. sglang/test/test_conversation.py +1 -1
  64. sglang/test/test_openai_protocol.py +1 -1
  65. sglang/test/test_programs.py +1 -1
  66. sglang/test/test_utils.py +2 -2
  67. {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/METADATA +157 -167
  68. sglang-0.1.22.dist-info/RECORD +103 -0
  69. {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/WHEEL +1 -1
  70. sglang-0.1.21.dist-info/RECORD +0 -82
  71. {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/LICENSE +0 -0
  72. {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/top_level.txt +0 -0
sglang/srt/models/internlm2.py ADDED
@@ -0,0 +1,317 @@
+ # -*- coding: utf-8 -*-
+ # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/7f62077af5159c625fe3ad1c812e6c1a2b93ba3b/vllm/model_executor/models/internlm2.py
+
+ from typing import Any, Dict, Iterable, Optional, Tuple
+
+ import torch
+ from torch import nn
+ from transformers import PretrainedConfig
+ from vllm.config import CacheConfig
+ from vllm.distributed import get_tensor_model_parallel_world_size
+ from vllm.model_executor.layers.activation import SiluAndMul
+ from vllm.model_executor.layers.layernorm import RMSNorm
+ from vllm.model_executor.layers.linear import (
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
+ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+ from vllm.model_executor.layers.rotary_embedding import get_rope
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.managers.controller.model_runner import InputMetadata
+
+
+ class InternLM2MLP(nn.Module):
+
+     def __init__(
+         self,
+         hidden_size: int,
+         intermediate_size: int,
+         hidden_act: str,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.gate_up_proj = MergedColumnParallelLinear(
+             hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+         )
+         self.w2 = RowParallelLinear(
+             intermediate_size, hidden_size, bias=False, quant_config=quant_config
+         )
+         if hidden_act != "silu":
+             raise ValueError(
+                 f"Unsupported activation: {hidden_act}. "
+                 "Only silu is supported for now."
+             )
+         self.act_fn = SiluAndMul()
+
+     def forward(self, x):
+         gate_up, _ = self.gate_up_proj(x)
+         x = self.act_fn(gate_up)
+         x, _ = self.w2(x)
+         return x
+
+
+ class InternLM2Attention(nn.Module):
+
+     def __init__(
+         self,
+         hidden_size: int,
+         num_heads: int,
+         num_kv_heads: int,
+         rope_theta: float = 10000,
+         rope_scaling: Optional[Dict[str, Any]] = None,
+         max_position_embeddings: int = 8192,
+         layer_id: int = 0,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.hidden_size = hidden_size
+         tp_size = get_tensor_model_parallel_world_size()
+         self.total_num_heads = num_heads
+         assert self.total_num_heads % tp_size == 0
+         self.num_heads = self.total_num_heads // tp_size
+         self.total_num_kv_heads = num_kv_heads
+         if self.total_num_kv_heads >= tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+         self.head_dim = hidden_size // self.total_num_heads
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+         self.rope_theta = rope_theta
+         self.max_position_embeddings = max_position_embeddings
+
+         self.wqkv = QKVParallelLinear(
+             hidden_size,
+             self.head_dim,
+             self.total_num_heads,
+             self.total_num_kv_heads,
+             bias=False,
+             quant_config=quant_config,
+         )
+         self.wo = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             hidden_size,
+             bias=False,
+             quant_config=quant_config,
+         )
+
+         self.rotary_emb = get_rope(
+             self.head_dim,
+             rotary_dim=self.head_dim,
+             max_position=max_position_embeddings,
+             base=rope_theta,
+             rope_scaling=rope_scaling,
+         )
+         self.attn = RadixAttention(
+             self.num_heads, self.head_dim, self.scaling, self.num_kv_heads, layer_id
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         qkv, _ = self.wqkv(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+         q, k = self.rotary_emb(positions, q, k)
+         attn_output = self.attn(q, k, v, input_metadata)
+         output, _ = self.wo(attn_output)
+         return output
+
+
+ class InternLMDecoderLayer(nn.Module):
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         layer_id: int = 0,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         rope_theta = getattr(config, "rope_theta", 10000)
+         rope_scaling = getattr(config, "rope_scaling", None)
+         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+         self.attention = InternLM2Attention(
+             hidden_size=self.hidden_size,
+             num_heads=config.num_attention_heads,
+             num_kv_heads=config.num_key_value_heads,
+             rope_theta=rope_theta,
+             rope_scaling=rope_scaling,
+             max_position_embeddings=max_position_embeddings,
+             layer_id=layer_id,
+             quant_config=quant_config,
+         )
+         self.feed_forward = InternLM2MLP(
+             hidden_size=self.hidden_size,
+             intermediate_size=config.intermediate_size,
+             hidden_act=config.hidden_act,
+             quant_config=quant_config,
+         )
+         self.attention_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+         residual: Optional[torch.Tensor],
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         # Self Attention
+         if residual is None:
+             residual = hidden_states
+             hidden_states = self.attention_norm(hidden_states)
+         else:
+             hidden_states, residual = self.attention_norm(hidden_states, residual)
+         hidden_states = self.attention(
+             positions=positions,
+             hidden_states=hidden_states,
+             input_metadata=input_metadata,
+         )
+
+         # Fully Connected
+         hidden_states, residual = self.ffn_norm(hidden_states, residual)
+         hidden_states = self.feed_forward(hidden_states)
+         return hidden_states, residual
+
+
+ class InternLM2Model(nn.Module):
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+         self.tok_embeddings = VocabParallelEmbedding(
+             config.vocab_size,
+             config.hidden_size,
+         )
+         self.layers = nn.ModuleList(
+             [
+                 InternLMDecoderLayer(config, i, quant_config)
+                 for i in range(config.num_hidden_layers)
+             ]
+         )
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         if input_embeds is None:
+             hidden_states = self.tok_embeddings(input_ids)
+         else:
+             hidden_states = input_embeds
+         residual = None
+         for i in range(len(self.layers)):
+             layer = self.layers[i]
+             hidden_states, residual = layer(
+                 positions,
+                 hidden_states,
+                 input_metadata,
+                 residual,
+             )
+         hidden_states, _ = self.norm(hidden_states, residual)
+         return hidden_states
+
+
+ class InternLM2ForCausalLM(nn.Module):
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.quant_config = quant_config
+         self.model = InternLM2Model(config, quant_config)
+         self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
+         self.logits_processor = LogitsProcessor(config)
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+         return self.logits_processor(
+             input_ids, hidden_states, self.output.weight, input_metadata
+         )
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("gate_up_proj", "w1", 0),
+             ("gate_up_proj", "w3", 1),
+         ]
+         params_dict = dict(self.named_parameters())
+         for name, loaded_weight in weights:
+             if "rotary_emb.inv_freq" in name:
+                 continue
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 if "wqkv" in name:
+                     config = self.config
+                     kv_groups = config.num_attention_heads // config.num_key_value_heads
+                     head_dim = config.hidden_size // config.num_attention_heads
+                     loaded_weight = loaded_weight.view(
+                         -1, 2 + kv_groups, head_dim, loaded_weight.shape[-1]
+                     )
+                     wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1], dim=1)
+                     wq = wq.reshape(-1, wq.shape[-1])
+                     wk = wk.reshape(-1, wk.shape[-1])
+                     wv = wv.reshape(-1, wv.shape[-1])
+                     weight_loader = param.weight_loader
+                     weight_loader(param, wq, "q")
+                     weight_loader(param, wk, "k")
+                     weight_loader(param, wv, "v")
+                 else:
+                     weight_loader = getattr(
+                         param, "weight_loader", default_weight_loader
+                     )
+                     weight_loader(param, loaded_weight)
+
+
+ EntryClass = InternLM2ForCausalLM
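
Note on the new model file: the one unusual step in InternLM2ForCausalLM.load_weights above is the fused wqkv tensor, which stores the query, key, and value projections interleaved per KV-head group and must be reshaped and split before the standard "q"/"k"/"v" shard loaders can consume it. The standalone sketch below reproduces just that reshape/split; the sizes (32 query heads, 8 KV heads, hidden size 4096) are illustrative assumptions for the example, not values read from a real checkpoint.

# Minimal sketch of the wqkv split performed by load_weights above (illustrative sizes).
import torch

num_attention_heads = 32
num_key_value_heads = 8
hidden_size = 4096

kv_groups = num_attention_heads // num_key_value_heads   # 4 query heads per KV head
head_dim = hidden_size // num_attention_heads             # 128

# InternLM2 checkpoints store q, k, v interleaved per KV group in a single wqkv weight.
wqkv = torch.randn(num_key_value_heads * (2 + kv_groups) * head_dim, hidden_size)

# view: (num_kv_heads, kv_groups + 2, head_dim, hidden_size), then split off q, k, v.
wqkv = wqkv.view(-1, 2 + kv_groups, head_dim, wqkv.shape[-1])
wq, wk, wv = torch.split(wqkv, [kv_groups, 1, 1], dim=1)

wq = wq.reshape(-1, wq.shape[-1])   # (num_heads * head_dim, hidden_size) -> (4096, 4096)
wk = wk.reshape(-1, wk.shape[-1])   # (num_kv_heads * head_dim, hidden_size) -> (1024, 4096)
wv = wv.reshape(-1, wv.shape[-1])   # (1024, 4096)

print(wq.shape, wk.shape, wv.shape)

The resulting wq/wk/wv slices then have the per-shard shapes that the fused QKV projection's weight loader expects for its "q", "k", and "v" shards.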
sglang/srt/models/llama2.py CHANGED
@@ -15,11 +15,6 @@ from vllm.distributed import (
  )
  from vllm.model_executor.layers.activation import SiluAndMul
  from vllm.model_executor.layers.layernorm import RMSNorm
- from vllm.model_executor.layers.linear import (
-     MergedColumnParallelLinear,
-     QKVParallelLinear,
-     RowParallelLinear,
- )
  from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -32,6 +27,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.managers.controller.model_runner import InputMetadata

+ MergedColumnParallelLinear = None
+ QKVParallelLinear = None
+ RowParallelLinear = None
+

  class LlamaMLP(nn.Module):
      def __init__(
@@ -163,9 +162,9 @@ class LlamaDecoderLayer(nn.Module):
          if rope_scaling is not None and getattr(
              config, "original_max_position_embeddings", None
          ):
-             rope_scaling[
-                 "original_max_position_embeddings"
-             ] = config.original_max_position_embeddings
+             rope_scaling["original_max_position_embeddings"] = (
+                 config.original_max_position_embeddings
+             )
          rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
          max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
          self.self_attn = LlamaAttention(
@@ -267,7 +266,25 @@ class LlamaForCausalLM(nn.Module):
          config: LlamaConfig,
          quant_config: Optional[QuantizationConfig] = None,
          cache_config: Optional[CacheConfig] = None,
+         efficient_weight_load=False,
      ) -> None:
+         global MergedColumnParallelLinear
+         global QKVParallelLinear
+         global RowParallelLinear
+
+         if efficient_weight_load:
+             from sglang.srt.layers.linear import (
+                 MergedColumnParallelLinear,
+                 QKVParallelLinear,
+                 RowParallelLinear,
+             )
+         else:
+             from vllm.model_executor.layers.linear import (
+                 MergedColumnParallelLinear,
+                 QKVParallelLinear,
+                 RowParallelLinear,
+             )
+
          super().__init__()
          self.config = config
          self.quant_config = quant_config
@@ -275,6 +292,7 @@ class LlamaForCausalLM(nn.Module):
          self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)

+     @torch.no_grad()
      def forward(
          self,
          input_ids: torch.Tensor,
@@ -287,7 +305,30 @@ class LlamaForCausalLM(nn.Module):
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )

-     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+     def get_module_name(self, name):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id, num_shard)
+             ("qkv_proj", "q_proj", "q", 3),
+             ("qkv_proj", "k_proj", "k", 3),
+             ("qkv_proj", "v_proj", "v", 3),
+             ("gate_up_proj", "gate_proj", 0, 2),
+             ("gate_up_proj", "up_proj", 1, 2),
+         ]
+         for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
+             if weight_name in name:
+                 return (
+                     name.replace(weight_name, param_name)[: -len(".weight")],
+                     num_shard,
+                 )
+         return name[: -len(".weight")], 1
+
+     def get_num_params(self):
+         params_dict = dict(self.named_parameters())
+         return len(params_dict)
+
+     def load_weights(
+         self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
+     ):
          stacked_params_mapping = [
              # (param_name, shard_name, shard_id)
              ("qkv_proj", "q_proj", "q"),
@@ -297,15 +338,14 @@ class LlamaForCausalLM(nn.Module):
              ("gate_up_proj", "up_proj", 1),
          ]
          params_dict = dict(self.named_parameters())
-         if get_tensor_model_parallel_rank() == 0:
-             weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
-         for name, loaded_weight in weights:
+
+         def load_weights_per_param(name, loaded_weight):
              if "rotary_emb.inv_freq" in name or "projector" in name:
-                 continue
+                 return
              if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                  # Models trained using ColossalAI may include these tensors in
                  # the checkpoint. Skip them.
-                 continue
+                 return
              for param_name, weight_name, shard_id in stacked_params_mapping:
                  if weight_name not in name:
                      continue
@@ -322,12 +362,21 @@ class LlamaForCausalLM(nn.Module):
              else:
                  # Skip loading extra bias for GPTQ models.
                  if name.endswith(".bias") and name not in params_dict:
-                     continue
+                     return
                  if name.startswith("model.vision_tower") and name not in params_dict:
-                     continue
+                     return
                  param = params_dict[name]
                  weight_loader = getattr(param, "weight_loader", default_weight_loader)
                  weight_loader(param, loaded_weight)

+         if name is None or loaded_weight is None:
+             if get_tensor_model_parallel_rank() == 0:
+                 weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
+
+             for name, loaded_weight in weights:
+                 load_weights_per_param(name, loaded_weight)
+         else:
+             load_weights_per_param(name, loaded_weight)
+

  EntryClass = LlamaForCausalLM
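
The net effect of the llama2.py changes is a second way to feed weights into the model: besides the original bulk iterator, load_weights can now be called once per tensor (with name= and loaded_weight=), while get_module_name and get_num_params expose enough metadata for an external loader to plan that streaming. The new sglang/srt/model_loader/ package presumably drives this path; the driver below is only an illustrative sketch of how such a caller might look, not code from the release.

# Hypothetical per-tensor loading driver (an assumption for illustration; only the model
# methods get_module_name, get_num_params, and load_weights(name=..., loaded_weight=...)
# come from the diff above).
def stream_weights(model, weight_iter):
    """Push checkpoint tensors into the model one at a time instead of in bulk."""
    total_params = model.get_num_params()          # e.g. for progress reporting
    loaded_modules = set()
    for name, tensor in weight_iter:
        # e.g. "model.layers.0.self_attn.q_proj.weight" -> ("model.layers.0.self_attn.qkv_proj", 3)
        module_name, num_shards = model.get_module_name(name)
        loaded_modules.add(module_name)            # stacked modules appear num_shards times
        # Passing weights=None with name/loaded_weight set selects the per-parameter branch.
        model.load_weights(None, name=name, loaded_weight=tensor)
    print(f"loaded {len(loaded_modules)} modules ({total_params} parameters registered)")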
sglang/srt/models/llama_classification.py CHANGED
@@ -31,6 +31,7 @@ class LlamaForClassification(nn.Module):
          )
          self.eos_token_id = config.eos_token_id

+     @torch.no_grad()
      def forward(
          self,
          input_ids: torch.Tensor,
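
The same one-line change, decorating the top-level forward with @torch.no_grad(), is applied to the remaining model files below (llava, llavavid, minicpm, mixtral, mixtral_quant, qwen, qwen2, qwen2_moe, stablelm). For readers new to PyTorch inference code, the quick illustration below shows what the decorator buys: autograd stops recording the computation graph, which saves memory and time. The toy module and shapes are assumptions for the example.

# Sketch of @torch.no_grad() behavior on an inference-only function (illustrative toy module).
import torch
from torch import nn

layer = nn.Linear(16, 16)
x = torch.randn(2, 16)

y = layer(x)
print(y.requires_grad)           # True: autograd tracks this result

@torch.no_grad()
def forward(inp):
    return layer(inp)

print(forward(x).requires_grad)  # False: no graph is built under no_grad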
sglang/srt/models/llava.py CHANGED
@@ -95,6 +95,7 @@ class LlavaLlamaForCausalLM(nn.Module):

          return image_features

+     @torch.no_grad()
      def forward(
          self,
          input_ids: torch.LongTensor,
sglang/srt/models/llavavid.py CHANGED
@@ -106,6 +106,7 @@ class LlavaVidForCausalLM(nn.Module):

          return image_features

+     @torch.no_grad()
      def forward(
          self,
          input_ids: torch.LongTensor,
sglang/srt/models/minicpm.py CHANGED
@@ -283,6 +283,7 @@ class MiniCPMForCausalLM(nn.Module):

          self.logits_processor = LogitsProcessor(config)

+     @torch.no_grad()
      def forward(
          self,
          input_ids: torch.Tensor,
sglang/srt/models/mixtral.py CHANGED
@@ -460,6 +460,7 @@ class MixtralForCausalLM(nn.Module):
          self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)

+     @torch.no_grad()
      def forward(
          self,
          input_ids: torch.Tensor,
sglang/srt/models/mixtral_quant.py CHANGED
@@ -322,6 +322,7 @@ class QuantMixtralForCausalLM(nn.Module):
          self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)

+     @torch.no_grad()
      def forward(
          self,
          input_ids: torch.Tensor,
sglang/srt/models/qwen.py CHANGED
@@ -237,6 +237,7 @@ class QWenLMHeadModel(nn.Module):
          self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)

+     @torch.no_grad()
      def forward(
          self,
          input_ids: torch.Tensor,
sglang/srt/models/qwen2.py CHANGED
@@ -261,6 +261,7 @@ class Qwen2ForCausalLM(nn.Module):
          self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)

+     @torch.no_grad()
      def forward(
          self,
          input_ids: torch.Tensor,
@@ -312,6 +313,11 @@ class Qwen2ForCausalLM(nn.Module):
                  param = params_dict[name]
                  weight_loader = getattr(param, "weight_loader", default_weight_loader)
                  weight_loader(param, loaded_weight)
+                 if (
+                     self.config.tie_word_embeddings
+                     and name == "model.embed_tokens.weight"
+                 ):
+                     weight_loader(params_dict["lm_head.weight"], loaded_weight)


  EntryClass = Qwen2ForCausalLM
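
The second qwen2.py hunk handles tied word embeddings: a Qwen2 checkpoint with tie_word_embeddings set may ship only model.embed_tokens.weight, so the loader now copies that tensor into lm_head.weight as well. A standalone illustration of the idea follows; the toy sizes and plain nn modules are assumptions for the example, not the model's actual layers.

# Sketch of weight tying: the output projection reuses the input embedding matrix,
# so a checkpoint that stores only the embedding must also populate lm_head at load time.
import torch
from torch import nn

vocab_size, hidden_size = 1000, 64                    # illustrative sizes
embed_tokens = nn.Embedding(vocab_size, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

# What the added branch effectively does when it sees model.embed_tokens.weight.
lm_head.weight.data.copy_(embed_tokens.weight.data)
assert torch.equal(lm_head.weight, embed_tokens.weight)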
sglang/srt/models/qwen2_moe.py CHANGED
@@ -355,6 +355,7 @@ class Qwen2MoeForCausalLM(nn.Module):
          self.logits_processor = LogitsProcessor(config)
          self.sampler = Sampler()

+     @torch.no_grad()
      def forward(
          self,
          input_ids: torch.Tensor,
@@ -400,9 +401,11 @@ class Qwen2MoeForCausalLM(nn.Module):
              # These are the weights for the experts
              # (param_name, weight_name, expert_id, shard_id)
              (
-                 "experts.w13_weight"
-                 if weight_name in ["gate_proj", "up_proj"]
-                 else "experts.w2_weight",
+                 (
+                     "experts.w13_weight"
+                     if weight_name in ["gate_proj", "up_proj"]
+                     else "experts.w2_weight"
+                 ),
                  f"experts.{expert_id}.{weight_name}.weight",
                  expert_id,
                  shard_id,
@@ -417,7 +420,7 @@ class Qwen2MoeForCausalLM(nn.Module):
          for name, loaded_weight in weights:
              if "rotary_emb.inv_freq" in name:
                  continue
-             for (param_name, weight_name, shard_id) in stacked_params_mapping:
+             for param_name, weight_name, shard_id in stacked_params_mapping:
                  # Skip non-stacked layers and experts (experts handled below).
                  if weight_name not in name:
                      continue
sglang/srt/models/stablelm.py CHANGED
@@ -235,6 +235,7 @@ class StableLmForCausalLM(nn.Module):
          self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)

+     @torch.no_grad()
      def forward(
          self,
          input_ids: torch.Tensor,