sglang 0.3.6.post3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/bench_one_batch.py +4 -0
  2. sglang/bench_serving.py +13 -0
  3. sglang/check_env.py +1 -1
  4. sglang/srt/_custom_ops.py +118 -0
  5. sglang/srt/configs/device_config.py +17 -0
  6. sglang/srt/configs/load_config.py +84 -0
  7. sglang/srt/configs/model_config.py +161 -4
  8. sglang/srt/configs/qwen2vl.py +5 -8
  9. sglang/srt/constrained/outlines_backend.py +6 -1
  10. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  11. sglang/srt/distributed/__init__.py +3 -0
  12. sglang/srt/distributed/communication_op.py +34 -0
  13. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  14. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  15. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  16. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  17. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  18. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  21. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  22. sglang/srt/distributed/parallel_state.py +1275 -0
  23. sglang/srt/distributed/utils.py +223 -0
  24. sglang/srt/hf_transformers_utils.py +37 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +13 -15
  26. sglang/srt/layers/attention/torch_native_backend.py +285 -0
  27. sglang/srt/layers/fused_moe_patch.py +20 -11
  28. sglang/srt/layers/linear.py +1 -0
  29. sglang/srt/layers/logits_processor.py +17 -3
  30. sglang/srt/layers/quantization/__init__.py +34 -0
  31. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  32. sglang/srt/lora/lora.py +1 -1
  33. sglang/srt/managers/io_struct.py +48 -2
  34. sglang/srt/managers/schedule_batch.py +18 -14
  35. sglang/srt/managers/schedule_policy.py +7 -4
  36. sglang/srt/managers/scheduler.py +76 -20
  37. sglang/srt/managers/tokenizer_manager.py +166 -68
  38. sglang/srt/managers/tp_worker.py +36 -3
  39. sglang/srt/managers/tp_worker_overlap_thread.py +21 -3
  40. sglang/srt/model_executor/cuda_graph_runner.py +16 -7
  41. sglang/srt/model_executor/forward_batch_info.py +9 -4
  42. sglang/srt/model_executor/model_runner.py +136 -150
  43. sglang/srt/model_loader/__init__.py +34 -0
  44. sglang/srt/model_loader/loader.py +1139 -0
  45. sglang/srt/model_loader/utils.py +41 -0
  46. sglang/srt/model_loader/weight_utils.py +640 -0
  47. sglang/srt/models/baichuan.py +9 -10
  48. sglang/srt/models/chatglm.py +6 -15
  49. sglang/srt/models/commandr.py +2 -3
  50. sglang/srt/models/dbrx.py +2 -3
  51. sglang/srt/models/deepseek.py +4 -11
  52. sglang/srt/models/deepseek_v2.py +3 -11
  53. sglang/srt/models/exaone.py +2 -3
  54. sglang/srt/models/gemma.py +2 -6
  55. sglang/srt/models/gemma2.py +3 -14
  56. sglang/srt/models/gemma2_reward.py +0 -1
  57. sglang/srt/models/gpt2.py +5 -12
  58. sglang/srt/models/gpt_bigcode.py +6 -22
  59. sglang/srt/models/grok.py +3 -3
  60. sglang/srt/models/internlm2.py +2 -3
  61. sglang/srt/models/internlm2_reward.py +0 -1
  62. sglang/srt/models/llama.py +97 -27
  63. sglang/srt/models/llama_classification.py +1 -2
  64. sglang/srt/models/llama_embedding.py +1 -2
  65. sglang/srt/models/llama_reward.py +2 -3
  66. sglang/srt/models/llava.py +1 -4
  67. sglang/srt/models/llavavid.py +1 -2
  68. sglang/srt/models/minicpm.py +4 -7
  69. sglang/srt/models/minicpm3.py +6 -19
  70. sglang/srt/models/mixtral.py +12 -5
  71. sglang/srt/models/mixtral_quant.py +2 -3
  72. sglang/srt/models/mllama.py +3 -7
  73. sglang/srt/models/olmo.py +2 -8
  74. sglang/srt/models/olmo2.py +0 -1
  75. sglang/srt/models/olmoe.py +3 -5
  76. sglang/srt/models/phi3_small.py +8 -8
  77. sglang/srt/models/qwen.py +2 -3
  78. sglang/srt/models/qwen2.py +10 -9
  79. sglang/srt/models/qwen2_moe.py +4 -11
  80. sglang/srt/models/qwen2_vl.py +2 -6
  81. sglang/srt/models/registry.py +99 -0
  82. sglang/srt/models/stablelm.py +2 -3
  83. sglang/srt/models/torch_native_llama.py +6 -12
  84. sglang/srt/models/xverse.py +2 -4
  85. sglang/srt/models/xverse_moe.py +4 -11
  86. sglang/srt/models/yivl.py +2 -3
  87. sglang/srt/openai_api/adapter.py +9 -5
  88. sglang/srt/openai_api/protocol.py +1 -0
  89. sglang/srt/server.py +267 -170
  90. sglang/srt/server_args.py +65 -31
  91. sglang/srt/utils.py +245 -28
  92. sglang/test/test_utils.py +7 -0
  93. sglang/version.py +1 -1
  94. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/METADATA +1 -1
  95. sglang-0.4.0.dist-info/RECORD +184 -0
  96. sglang-0.3.6.post3.dist-info/RECORD +0 -162
  97. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
  98. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
  99. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
sglang/srt/models/gpt_bigcode.py CHANGED
@@ -21,9 +21,7 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import GPTBigCodeConfig
-from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.linear import (
@@ -36,6 +34,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class GPTBigCodeAttention(nn.Module):
@@ -44,7 +43,6 @@ class GPTBigCodeAttention(nn.Module):
         self,
         layer_id: int,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -145,7 +143,6 @@ class GPTBigCodeBlock(nn.Module):
         self,
         layer_id: int,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -153,7 +150,7 @@ class GPTBigCodeBlock(nn.Module):
         inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.attn = GPTBigCodeAttention(layer_id, config, cache_config, quant_config)
+        self.attn = GPTBigCodeAttention(layer_id, config, quant_config)
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.mlp = GPTBigMLP(inner_dim, config, quant_config)

@@ -183,20 +180,14 @@ class GPTBigCodeModel(nn.Module):
     def __init__(
         self,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
         assert not config.add_cross_attention

         self.embed_dim = config.hidden_size
-        lora_vocab = (
-            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
-            if lora_config
-            else 0
-        )
+        lora_vocab = 0
         self.vocab_size = config.vocab_size + lora_vocab
         self.wte = VocabParallelEmbedding(
             self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size
@@ -204,7 +195,7 @@ class GPTBigCodeModel(nn.Module):
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList(
             [
-                GPTBigCodeBlock(i, config, cache_config, quant_config)
+                GPTBigCodeBlock(i, config, quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -243,23 +234,16 @@ class GPTBigCodeForCausalLM(nn.Module):
     def __init__(
         self,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()

         self.config = config
-        self.lora_config = lora_config

         self.quant_config = quant_config
-        self.transformer = GPTBigCodeModel(
-            config, cache_config, quant_config, lora_config
-        )
+        self.transformer = GPTBigCodeModel(config, quant_config)
         self.lm_head = self.transformer.wte
         self.unpadded_vocab_size = config.vocab_size
-        if lora_config:
-            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
         self.logits_processor = LogitsProcessor(config)

     @torch.no_grad()
@@ -271,7 +255,7 @@ class GPTBigCodeForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
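The same migration repeats across the model files below: `default_weight_loader` now comes from `sglang.srt.model_loader.weight_utils` instead of vllm, the `cache_config` (and here `lora_config`) constructor parameters are dropped, and the logits processor receives the `lm_head` module rather than its `.weight`. A minimal stand-in sketch of what a default weight loader conventionally does (illustration only; the packaged helper's exact behavior is an assumption, not copied from this release):

# Toy stand-in for a "default" weight loader: copy a checkpoint tensor
# into an existing parameter after a shape check. Not sglang's implementation.
import torch
from torch import nn

def toy_default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor) -> None:
    assert param.data.shape == loaded_weight.shape, "shape mismatch"
    param.data.copy_(loaded_weight)

p = nn.Parameter(torch.zeros(2, 3))
toy_default_weight_loader(p, torch.ones(2, 3))
print(p.data.sum().item())  # 6.0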
sglang/srt/models/grok.py CHANGED
@@ -24,7 +24,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
@@ -43,6 +42,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.loader import DefaultModelLoader
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class Grok1MoE(nn.Module):
@@ -285,7 +286,6 @@ class Grok1ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -304,7 +304,7 @@ class Grok1ForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/internlm2.py CHANGED
@@ -21,7 +21,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -38,6 +37,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class InternLM2MLP(nn.Module):
@@ -251,7 +251,6 @@ class InternLM2ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -270,7 +269,7 @@ class InternLM2ForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.output.weight, forward_batch
+            input_ids, hidden_states, self.output, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/internlm2_reward.py CHANGED
@@ -29,7 +29,6 @@ class InternLM2ForRewardModel(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
sglang/srt/models/llama.py CHANGED
@@ -16,6 +16,7 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
 """Inference-only LLaMA model compatible with HuggingFace weights."""

+import logging
 from typing import Any, Dict, Iterable, Optional, Tuple

 import torch
@@ -23,7 +24,6 @@ from torch import nn
 from transformers import LlamaConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -43,7 +43,11 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import make_layers
+from sglang.utils import get_exception_traceback
+
+logger = logging.getLogger(__name__)


 class LlamaMLP(nn.Module):
@@ -255,6 +259,7 @@ class LlamaModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            quant_config=quant_config,
         )
         self.layers = make_layers(
             config.num_hidden_layers,
@@ -295,16 +300,30 @@ class LlamaForCausalLM(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
         self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = LlamaModel(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        # Llama 3.2 1B Insturct set tie_word_embeddings to True
+        # Llama 3.1 8B Insturct set tie_word_embeddings to False
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        self.stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]

     @torch.no_grad()
     def forward(
@@ -318,7 +337,7 @@ class LlamaForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head.weight, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch
             )
         else:
             return self.pooler(hidden_states, forward_batch)
@@ -349,15 +368,7 @@ class LlamaForCausalLM(nn.Module):
         return params_mapping.get(name, name)

     def get_module_name_from_weight_name(self, name):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id, num_shard)
-            ("qkv_proj", "q_proj", "q", 3),
-            ("qkv_proj", "k_proj", "k", 3),
-            ("qkv_proj", "v_proj", "v", 3),
-            ("gate_up_proj", "gate_proj", 0, 2),
-            ("gate_up_proj", "up_proj", 1, 2),
-        ]
-        for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
+        for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
             if weight_name in name:
                 return (
                     name.replace(weight_name, param_name)[: -len(".weight")],
@@ -378,13 +389,8 @@ class LlamaForCausalLM(nn.Module):
             (".gate_up_proj", ".gate_proj", 0),
             (".gate_up_proj", ".up_proj", 1),
         ]
-        params_dict = dict(self.named_parameters())

-        load_tie_word_embeddings = (
-            hasattr(self.config, "tie_word_embeddings")
-            and self.config.tie_word_embeddings
-            and "lm_head.weight" in params_dict
-        )
+        params_dict = dict(self.named_parameters())

         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
@@ -418,16 +424,80 @@ class LlamaForCausalLM(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)

-                if load_tie_word_embeddings and name == "model.embed_tokens.weight":
-                    embed_tokens_weight = loaded_weight
+        apply_torchao_config_(self, params_dict, set(["proj.weight"]))

-        if load_tie_word_embeddings:
-            # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
-            param = self.lm_head.weight
-            weight_loader = getattr(param, "weight_loader", default_weight_loader)
-            weight_loader(param, embed_tokens_weight)
+    def get_weights_by_name(
+        self, name: str, truncate_size: int = 100, tp_size: int = 1
+    ) -> Optional[torch.Tensor]:
+        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.
+
+        Only used for unit test with an unoptimized performance.
+        For optimized performance, please use torch.save and torch.load.
+        """
+        try:
+            if name == "lm_head.weight" and self.config.tie_word_embeddings:
+                logger.info(
+                    "word embedding is tied for this model, return embed_tokens.weight as lm_head.weight."
+                )
+                return (
+                    self.model.embed_tokens.weight.cpu()
+                    .to(torch.float32)
+                    .numpy()
+                    .tolist()[:truncate_size]
+                )

-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
+            mapped_name = name
+            mapped_shard_id = None
+            for param_name, weight_name, shard_id in self.stacked_params_mapping:
+                if weight_name in name:
+                    mapped_name = name.replace(weight_name, param_name)
+                    mapped_shard_id = shard_id
+                    break
+            params_dict = dict(self.named_parameters())
+            param = params_dict[mapped_name]
+            if mapped_shard_id is not None:
+                if mapped_shard_id in ["q", "k", "v"]:
+                    num_heads = self.config.num_attention_heads // tp_size
+                    num_kv_heads = self.config.num_key_value_heads // tp_size
+                    head_dim = (
+                        self.config.hidden_size // self.config.num_attention_heads
+                    )
+                    if mapped_shard_id == "q":
+                        offset = 0
+                        size = num_heads * head_dim
+                    elif mapped_shard_id == "k":
+                        offset = num_heads * head_dim
+                        size = num_kv_heads * head_dim
+                    elif mapped_shard_id == "v":
+                        offset = (num_heads + num_kv_heads) * head_dim
+                        size = num_kv_heads * head_dim
+                    weight = param.data.narrow(0, offset, size)
+                elif mapped_shard_id in [0, 1]:
+                    intermediate_size = self.config.intermediate_size
+                    slice_size = intermediate_size // tp_size
+                    if mapped_shard_id == 0:  # gate_proj
+                        offset = 0
+                        size = slice_size
+                    elif mapped_shard_id == 1:  # up_proj
+                        offset = slice_size
+                        size = slice_size
+
+                    weight = param.data.narrow(0, offset, size)
+                else:
+                    weight = param.data
+            else:
+                weight = param.data
+            if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
+                gathered_weights = [torch.zeros_like(weight) for _ in range(tp_size)]
+                torch.distributed.all_gather(gathered_weights, weight)
+                weight = torch.cat(gathered_weights, dim=1)
+            return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
+
+        except Exception:
+            logger.error(
+                f"Error getting weights by name {name} in LlamaForCausalLM: {get_exception_traceback()}"
+            )
+            return None


 class Phi3ForCausalLM(LlamaForCausalLM):
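A worked example of the shard offsets that `get_weights_by_name` computes for the fused `qkv_proj` and `gate_up_proj` parameters, using hypothetical Llama-style sizes (32 attention heads, 8 KV heads, hidden size 4096, intermediate size 14336, tp_size 1); the numbers are illustrative only, not read from any particular checkpoint:

# Hypothetical sizes; the fused qkv_proj weight stacks q, k, v rows along dim 0,
# and gate_up_proj stacks gate_proj rows followed by up_proj rows.
num_heads, num_kv_heads, hidden_size, intermediate_size, tp_size = 32, 8, 4096, 14336, 1
head_dim = hidden_size // num_heads                 # 128
q_size = (num_heads // tp_size) * head_dim          # 4096
kv_size = (num_kv_heads // tp_size) * head_dim      # 1024
qkv_offsets = {
    "q": (0, q_size),                    # rows [0, 4096)
    "k": (q_size, kv_size),              # rows [4096, 5120)
    "v": (q_size + kv_size, kv_size),    # rows [5120, 6144)
}
slice_size = intermediate_size // tp_size
gate_up_offsets = {
    0: (0, slice_size),                  # gate_proj rows
    1: (slice_size, slice_size),         # up_proj rows
}
print(qkv_offsets)
print(gate_up_offsets)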
sglang/srt/models/llama_classification.py CHANGED
@@ -17,11 +17,11 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel


@@ -30,7 +30,6 @@ class LlamaForClassification(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
sglang/srt/models/llama_embedding.py CHANGED
@@ -3,10 +3,10 @@ from typing import Iterable, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.model_executor.model_runner import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaModel


@@ -15,7 +15,6 @@ class LlamaEmbeddingModel(nn.Module):
         self,
         config: LlamaConfig,
         quant_config=None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.model = LlamaModel(config, quant_config=quant_config)
sglang/srt/models/llama_reward.py CHANGED
@@ -21,6 +21,7 @@ from transformers import LlamaConfig
 from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel


@@ -29,7 +30,6 @@ class LlamaForSequenceClassification(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -84,9 +84,8 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
-        super().__init__(config, quant_config, cache_config)
+        super().__init__(config, quant_config)
         self.weights = self.Weights(config.hidden_size, self.num_labels)

     @torch.no_grad()
sglang/srt/models/llava.py CHANGED
@@ -29,7 +29,6 @@ from transformers import (
     SiglipVisionModel,
 )
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.schedule_batch import ImageInputs
@@ -39,6 +38,7 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
@@ -451,7 +451,6 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()

@@ -473,7 +472,6 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()

@@ -506,7 +504,6 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()

sglang/srt/models/llavavid.py CHANGED
@@ -20,11 +20,11 @@ import torch
 from torch import nn
 from transformers import CLIPVisionModel, LlavaConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM


@@ -33,7 +33,6 @@ class LlavaVidForCausalLM(nn.Module):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
sglang/srt/models/minicpm.py CHANGED
@@ -20,7 +20,6 @@ import torch
 from torch import nn
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -37,6 +36,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class MiniCPMMLP(nn.Module):
@@ -275,7 +275,6 @@ class MiniCPMForCausalLM(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -308,12 +307,10 @@ class MiniCPMForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         hidden_states = hidden_states / self.scale_width
         if self.config.tie_word_embeddings:
-            lm_head_weight = self.model.embed_tokens.weight
+            lm_head = self.model.embed_tokens
         else:
-            lm_head_weight = self.lm_head.weight
-        return self.logits_processor(
-            input_ids, hidden_states, lm_head_weight, forward_batch
-        )
+            lm_head = self.lm_head
+        return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
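As in llama.py above, the forward pass now picks an `lm_head` module (the shared `embed_tokens` when `tie_word_embeddings` is set) and passes the module itself, not its weight, to the logits processor. A toy, self-contained PyTorch sketch of that tying pattern (hypothetical class, not sglang's):

# Toy illustration of weight tying: when tied, the output head simply reuses the
# embedding module, so no weight copy is needed at load time.
import torch
from torch import nn

class TinyTiedLM(nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int, tie_word_embeddings: bool):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        if tie_word_embeddings:
            self.lm_head = self.embed_tokens  # same module, same storage
        else:
            self.lm_head = nn.Embedding(vocab_size, hidden_size)  # independent head

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        hidden = self.embed_tokens(input_ids)
        # The "logits processor" reads the weight off whichever module it was given.
        return hidden @ self.lm_head.weight.t()

m = TinyTiedLM(vocab_size=16, hidden_size=4, tie_word_embeddings=True)
print(m.lm_head.weight.data_ptr() == m.embed_tokens.weight.data_ptr())  # True: shared storage
print(m(torch.tensor([[1, 2, 3]])).shape)  # torch.Size([1, 3, 16])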
sglang/srt/models/minicpm3.py CHANGED
@@ -27,7 +27,6 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -40,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import is_flashinfer_available

 if is_flashinfer_available():
@@ -105,7 +105,6 @@ class MiniCPM3Attention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
     ) -> None:
@@ -249,7 +248,6 @@ class MiniCPM3AttentionMLA(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
     ) -> None:
@@ -406,7 +404,6 @@ class MiniCPM3DecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -430,7 +427,6 @@ class MiniCPM3DecoderLayer(nn.Module):
                rope_theta=rope_theta,
                rope_scaling=rope_scaling,
                max_position_embeddings=max_position_embeddings,
-               cache_config=cache_config,
                quant_config=quant_config,
                layer_id=layer_id,
            )
@@ -449,7 +445,6 @@ class MiniCPM3DecoderLayer(nn.Module):
                rope_theta=rope_theta,
                rope_scaling=rope_scaling,
                max_position_embeddings=max_position_embeddings,
-               cache_config=cache_config,
                quant_config=quant_config,
                layer_id=layer_id,
            )
@@ -498,7 +493,6 @@ class MiniCPM3Model(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -512,9 +506,7 @@ class MiniCPM3Model(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                MiniCPM3DecoderLayer(
-                    config, i, cache_config=cache_config, quant_config=quant_config
-                )
+                MiniCPM3DecoderLayer(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -549,7 +541,6 @@ class MiniCPM3ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -557,9 +548,7 @@ class MiniCPM3ForCausalLM(nn.Module):

         self.num_experts = getattr(self.config, "num_experts", 0)
         self.quant_config = quant_config
-        self.model = MiniCPM3Model(
-            config, cache_config=cache_config, quant_config=quant_config
-        )
+        self.model = MiniCPM3Model(config, quant_config=quant_config)
         # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         if not self.config.tie_word_embeddings:
             self.lm_head = ParallelLMHead(
@@ -585,12 +574,10 @@ class MiniCPM3ForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         hidden_states = hidden_states / self.scale_width
         if self.config.tie_word_embeddings:
-            lm_head_weight = self.model.embed_tokens.weight
+            lm_head = self.model.embed_tokens
         else:
-            lm_head_weight = self.lm_head.weight
-        return self.logits_processor(
-            input_ids, hidden_states, lm_head_weight, forward_batch
-        )
+            lm_head = self.lm_head
+        return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [