sglang 0.3.6.post2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. sglang/bench_offline_throughput.py +55 -2
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +4 -3
  4. sglang/bench_serving.py +13 -0
  5. sglang/check_env.py +1 -1
  6. sglang/launch_server.py +3 -2
  7. sglang/srt/_custom_ops.py +118 -0
  8. sglang/srt/configs/device_config.py +17 -0
  9. sglang/srt/configs/load_config.py +84 -0
  10. sglang/srt/configs/model_config.py +161 -4
  11. sglang/srt/configs/qwen2vl.py +5 -8
  12. sglang/srt/constrained/outlines_backend.py +6 -1
  13. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  14. sglang/srt/distributed/__init__.py +3 -0
  15. sglang/srt/distributed/communication_op.py +34 -0
  16. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  17. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  18. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  19. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  20. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  21. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  22. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  23. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  24. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  25. sglang/srt/distributed/parallel_state.py +1275 -0
  26. sglang/srt/distributed/utils.py +223 -0
  27. sglang/srt/hf_transformers_utils.py +37 -1
  28. sglang/srt/layers/attention/flashinfer_backend.py +13 -15
  29. sglang/srt/layers/attention/torch_native_backend.py +285 -0
  30. sglang/srt/layers/fused_moe_patch.py +20 -11
  31. sglang/srt/layers/linear.py +1 -0
  32. sglang/srt/layers/logits_processor.py +17 -3
  33. sglang/srt/layers/quantization/__init__.py +34 -0
  34. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  35. sglang/srt/lora/lora.py +1 -1
  36. sglang/srt/managers/data_parallel_controller.py +7 -11
  37. sglang/srt/managers/detokenizer_manager.py +7 -4
  38. sglang/srt/managers/image_processor.py +1 -1
  39. sglang/srt/managers/io_struct.py +48 -12
  40. sglang/srt/managers/schedule_batch.py +42 -36
  41. sglang/srt/managers/schedule_policy.py +7 -4
  42. sglang/srt/managers/scheduler.py +111 -46
  43. sglang/srt/managers/session_controller.py +0 -3
  44. sglang/srt/managers/tokenizer_manager.py +169 -100
  45. sglang/srt/managers/tp_worker.py +36 -3
  46. sglang/srt/managers/tp_worker_overlap_thread.py +32 -5
  47. sglang/srt/model_executor/cuda_graph_runner.py +16 -7
  48. sglang/srt/model_executor/forward_batch_info.py +9 -4
  49. sglang/srt/model_executor/model_runner.py +136 -150
  50. sglang/srt/model_loader/__init__.py +34 -0
  51. sglang/srt/model_loader/loader.py +1139 -0
  52. sglang/srt/model_loader/utils.py +41 -0
  53. sglang/srt/model_loader/weight_utils.py +640 -0
  54. sglang/srt/models/baichuan.py +9 -10
  55. sglang/srt/models/chatglm.py +6 -15
  56. sglang/srt/models/commandr.py +2 -3
  57. sglang/srt/models/dbrx.py +2 -3
  58. sglang/srt/models/deepseek.py +4 -11
  59. sglang/srt/models/deepseek_v2.py +3 -11
  60. sglang/srt/models/exaone.py +2 -3
  61. sglang/srt/models/gemma.py +2 -6
  62. sglang/srt/models/gemma2.py +3 -14
  63. sglang/srt/models/gemma2_reward.py +0 -1
  64. sglang/srt/models/gpt2.py +5 -12
  65. sglang/srt/models/gpt_bigcode.py +6 -22
  66. sglang/srt/models/grok.py +14 -51
  67. sglang/srt/models/internlm2.py +2 -3
  68. sglang/srt/models/internlm2_reward.py +0 -1
  69. sglang/srt/models/llama.py +97 -27
  70. sglang/srt/models/llama_classification.py +1 -2
  71. sglang/srt/models/llama_embedding.py +1 -2
  72. sglang/srt/models/llama_reward.py +2 -3
  73. sglang/srt/models/llava.py +10 -12
  74. sglang/srt/models/llavavid.py +1 -2
  75. sglang/srt/models/minicpm.py +4 -7
  76. sglang/srt/models/minicpm3.py +6 -19
  77. sglang/srt/models/mixtral.py +12 -5
  78. sglang/srt/models/mixtral_quant.py +2 -3
  79. sglang/srt/models/mllama.py +3 -7
  80. sglang/srt/models/olmo.py +2 -8
  81. sglang/srt/models/olmo2.py +391 -0
  82. sglang/srt/models/olmoe.py +3 -5
  83. sglang/srt/models/phi3_small.py +8 -8
  84. sglang/srt/models/qwen.py +2 -3
  85. sglang/srt/models/qwen2.py +10 -9
  86. sglang/srt/models/qwen2_moe.py +4 -11
  87. sglang/srt/models/qwen2_vl.py +12 -9
  88. sglang/srt/models/registry.py +99 -0
  89. sglang/srt/models/stablelm.py +2 -3
  90. sglang/srt/models/torch_native_llama.py +6 -12
  91. sglang/srt/models/xverse.py +2 -4
  92. sglang/srt/models/xverse_moe.py +4 -11
  93. sglang/srt/models/yivl.py +2 -3
  94. sglang/srt/openai_api/adapter.py +10 -6
  95. sglang/srt/openai_api/protocol.py +1 -0
  96. sglang/srt/server.py +303 -204
  97. sglang/srt/server_args.py +65 -31
  98. sglang/srt/utils.py +253 -48
  99. sglang/test/test_utils.py +27 -7
  100. sglang/utils.py +2 -2
  101. sglang/version.py +1 -1
  102. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/METADATA +2 -1
  103. sglang-0.4.0.dist-info/RECORD +184 -0
  104. sglang/srt/layers/fused_moe_grok/__init__.py +0 -1
  105. sglang/srt/layers/fused_moe_grok/fused_moe.py +0 -692
  106. sglang/srt/layers/fused_moe_grok/layer.py +0 -630
  107. sglang-0.3.6.post2.dist-info/RECORD +0 -164
  108. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
  109. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
  110. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
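A change that recurs through nearly every model file below is the move of the weight-loading helpers from vllm into sglang itself: imports from vllm.model_executor.model_loader.weight_utils are replaced by sglang.srt.model_loader.weight_utils, backed by the new sglang/srt/model_loader/ package listed above. A minimal compatibility sketch for downstream code that may run against either version; the fallback path is the pre-0.4.0 location shown in the removed lines, and the helper itself is hypothetical, not part of either package:

    # Resolve default_weight_loader against sglang 0.4.0 first, then fall back
    # to the vllm location that sglang <= 0.3.6.post2 imported from.
    try:
        from sglang.srt.model_loader.weight_utils import default_weight_loader
    except ImportError:
        from vllm.model_executor.model_loader.weight_utils import default_weight_loader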
sglang/srt/models/gpt_bigcode.py CHANGED
@@ -21,9 +21,7 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import GPTBigCodeConfig
-from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.linear import (
@@ -36,6 +34,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
 class GPTBigCodeAttention(nn.Module):
@@ -44,7 +43,6 @@ class GPTBigCodeAttention(nn.Module):
         self,
         layer_id: int,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -145,7 +143,6 @@ class GPTBigCodeBlock(nn.Module):
         self,
         layer_id: int,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -153,7 +150,7 @@ class GPTBigCodeBlock(nn.Module):
         inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
 
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.attn = GPTBigCodeAttention(layer_id, config, cache_config, quant_config)
+        self.attn = GPTBigCodeAttention(layer_id, config, quant_config)
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.mlp = GPTBigMLP(inner_dim, config, quant_config)
 
@@ -183,20 +180,14 @@ class GPTBigCodeModel(nn.Module):
     def __init__(
         self,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
         assert not config.add_cross_attention
 
         self.embed_dim = config.hidden_size
-        lora_vocab = (
-            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
-            if lora_config
-            else 0
-        )
+        lora_vocab = 0
         self.vocab_size = config.vocab_size + lora_vocab
         self.wte = VocabParallelEmbedding(
             self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size
@@ -204,7 +195,7 @@ class GPTBigCodeModel(nn.Module):
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList(
             [
-                GPTBigCodeBlock(i, config, cache_config, quant_config)
+                GPTBigCodeBlock(i, config, quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -243,23 +234,16 @@ class GPTBigCodeForCausalLM(nn.Module):
     def __init__(
         self,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
 
         self.config = config
-        self.lora_config = lora_config
 
         self.quant_config = quant_config
-        self.transformer = GPTBigCodeModel(
-            config, cache_config, quant_config, lora_config
-        )
+        self.transformer = GPTBigCodeModel(config, quant_config)
         self.lm_head = self.transformer.wte
         self.unpadded_vocab_size = config.vocab_size
-        if lora_config:
-            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.logits_processor = LogitsProcessor(config)
 
     @torch.no_grad()
@@ -271,7 +255,7 @@ class GPTBigCodeForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
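The gpt_bigcode.py hunks above (and the grok.py, internlm2.py, and llama.py hunks below) switch the logits-processor call from passing self.lm_head.weight to passing the lm_head module itself. A simplified illustration of the pattern, not sglang's actual LogitsProcessor: with the module in hand, the processor is no longer tied to a plain dense weight tensor and can work with tied or quantized heads.

    import torch
    from torch import nn

    def compute_logits(hidden_states: torch.Tensor, lm_head: nn.Module) -> torch.Tensor:
        # Sketch only: project hidden states through the head's embedding weight.
        return torch.matmul(hidden_states, lm_head.weight.t())

    head = nn.Embedding(10, 4)             # stands in for ParallelLMHead / tied embed_tokens
    hidden = torch.randn(2, 4)
    logits = compute_logits(hidden, head)  # shape (2, 10)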
sglang/srt/models/grok.py CHANGED
@@ -16,22 +16,16 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Grok1 model."""
 
-import warnings
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.loader import DefaultModelLoader
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-from sglang.srt.layers.fused_moe_grok import FusedMoE
+from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
@@ -41,11 +35,15 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.loader import DefaultModelLoader
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
 class Grok1MoE(nn.Module):
@@ -288,22 +286,15 @@ class Grok1ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
+        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = Grok1Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
 
-        # Monkey patch _prepare_weights to load pre-sharded weights
-        setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
-
-        self.use_presharded_weights = True
-
-        warnings.filterwarnings("ignore", category=FutureWarning)
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -313,7 +304,7 @@ class Grok1ForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -357,28 +348,23 @@ class Grok1ForCausalLM(nn.Module):
                    continue
                name = name.replace(weight_name, param_name)
 
-                if self.use_presharded_weights:
-                    extra_kwargs = {
-                        "use_presharded_weights": self.use_presharded_weights
-                    }
-                else:
-                    extra_kwargs = {}
-
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(
                    param,
                    loaded_weight,
-                    weight_name,
+                    name,
                    shard_id=shard_id,
                    expert_id=expert_id,
-                    **extra_kwargs,
                )
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
+                # Skip loading kv_scale from ckpts towards new design.
+                if name.endswith(".kv_scale") and name not in params_dict:
+                    continue
                if name is None:
                    continue
 
@@ -388,30 +374,7 @@ class Grok1ForCausalLM(nn.Module):
                )
                weight_loader(param, loaded_weight)
 
-
-old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")
-
-
-def _prepare_presharded_weights(
-    self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
-) -> Tuple[str, List[str], bool]:
-    import glob
-    import os
-
-    if get_tensor_model_parallel_world_size() == 1:
-        return old_prepare_weights(self, model_name_or_path, revision, fall_back_to_pt)
-
-    tp_rank = get_tensor_model_parallel_rank()
-    allow_patterns = [f"*-{tp_rank:03d}.bin"]
-
-    hf_folder = model_name_or_path
-
-    hf_weights_files: List[str] = []
-    for pattern in allow_patterns:
-        hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
-    use_safetensors = False
-
-    return hf_folder, hf_weights_files, use_safetensors
+        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
 
 
 class Grok1ModelForCausalLM(Grok1ForCausalLM):
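The grok.py hunks above delete the _prepare_presharded_weights monkey patch, which selected one checkpoint shard per tensor-parallel rank by filename. A standalone sketch of that file-selection idea, with the rank passed in explicitly rather than read from vllm's distributed state; the function name is made up for illustration:

    import glob
    import os
    from typing import List, Tuple

    def find_presharded_weight_files(folder: str, tp_rank: int) -> Tuple[str, List[str], bool]:
        # Mirrors the removed pattern: each rank loads only the "*-{rank:03d}.bin"
        # shards that were produced for it ahead of time.
        pattern = os.path.join(folder, f"*-{tp_rank:03d}.bin")
        files = sorted(glob.glob(pattern))
        use_safetensors = False  # the pre-sharded checkpoints targeted here are .bin files
        return folder, files, use_safetensors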
sglang/srt/models/internlm2.py CHANGED
@@ -21,7 +21,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -38,6 +37,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
 class InternLM2MLP(nn.Module):
@@ -251,7 +251,6 @@ class InternLM2ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -270,7 +269,7 @@ class InternLM2ForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.output.weight, forward_batch
+            input_ids, hidden_states, self.output, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/internlm2_reward.py CHANGED
@@ -29,7 +29,6 @@ class InternLM2ForRewardModel(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
sglang/srt/models/llama.py CHANGED
@@ -16,6 +16,7 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
 """Inference-only LLaMA model compatible with HuggingFace weights."""
 
+import logging
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
@@ -23,7 +24,6 @@ from torch import nn
 from transformers import LlamaConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -43,7 +43,11 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import make_layers
+from sglang.utils import get_exception_traceback
+
+logger = logging.getLogger(__name__)
 
 
 class LlamaMLP(nn.Module):
@@ -255,6 +259,7 @@ class LlamaModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            quant_config=quant_config,
         )
         self.layers = make_layers(
             config.num_hidden_layers,
@@ -295,16 +300,30 @@ class LlamaForCausalLM(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
         self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = LlamaModel(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        # Llama 3.2 1B Insturct set tie_word_embeddings to True
+        # Llama 3.1 8B Insturct set tie_word_embeddings to False
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        self.stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
 
     @torch.no_grad()
     def forward(
@@ -318,7 +337,7 @@ class LlamaForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head.weight, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch
             )
         else:
             return self.pooler(hidden_states, forward_batch)
@@ -349,15 +368,7 @@ class LlamaForCausalLM(nn.Module):
         return params_mapping.get(name, name)
 
     def get_module_name_from_weight_name(self, name):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id, num_shard)
-            ("qkv_proj", "q_proj", "q", 3),
-            ("qkv_proj", "k_proj", "k", 3),
-            ("qkv_proj", "v_proj", "v", 3),
-            ("gate_up_proj", "gate_proj", 0, 2),
-            ("gate_up_proj", "up_proj", 1, 2),
-        ]
-        for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
+        for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
             if weight_name in name:
                 return (
                     name.replace(weight_name, param_name)[: -len(".weight")],
@@ -378,13 +389,8 @@ class LlamaForCausalLM(nn.Module):
             (".gate_up_proj", ".gate_proj", 0),
             (".gate_up_proj", ".up_proj", 1),
         ]
-        params_dict = dict(self.named_parameters())
 
-        load_tie_word_embeddings = (
-            hasattr(self.config, "tie_word_embeddings")
-            and self.config.tie_word_embeddings
-            and "lm_head.weight" in params_dict
-        )
+        params_dict = dict(self.named_parameters())
 
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
@@ -418,16 +424,80 @@ class LlamaForCausalLM(nn.Module):
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
 
-            if load_tie_word_embeddings and name == "model.embed_tokens.weight":
-                embed_tokens_weight = loaded_weight
+        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
 
-        if load_tie_word_embeddings:
-            # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
-            param = self.lm_head.weight
-            weight_loader = getattr(param, "weight_loader", default_weight_loader)
-            weight_loader(param, embed_tokens_weight)
+    def get_weights_by_name(
+        self, name: str, truncate_size: int = 100, tp_size: int = 1
+    ) -> Optional[torch.Tensor]:
+        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.
+
+        Only used for unit test with an unoptimized performance.
+        For optimized performance, please use torch.save and torch.load.
+        """
+        try:
+            if name == "lm_head.weight" and self.config.tie_word_embeddings:
+                logger.info(
+                    "word embedding is tied for this model, return embed_tokens.weight as lm_head.weight."
+                )
+                return (
+                    self.model.embed_tokens.weight.cpu()
+                    .to(torch.float32)
+                    .numpy()
+                    .tolist()[:truncate_size]
+                )
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
+            mapped_name = name
+            mapped_shard_id = None
+            for param_name, weight_name, shard_id in self.stacked_params_mapping:
+                if weight_name in name:
+                    mapped_name = name.replace(weight_name, param_name)
+                    mapped_shard_id = shard_id
+                    break
+            params_dict = dict(self.named_parameters())
+            param = params_dict[mapped_name]
+            if mapped_shard_id is not None:
+                if mapped_shard_id in ["q", "k", "v"]:
+                    num_heads = self.config.num_attention_heads // tp_size
+                    num_kv_heads = self.config.num_key_value_heads // tp_size
+                    head_dim = (
+                        self.config.hidden_size // self.config.num_attention_heads
+                    )
+                    if mapped_shard_id == "q":
+                        offset = 0
+                        size = num_heads * head_dim
+                    elif mapped_shard_id == "k":
+                        offset = num_heads * head_dim
+                        size = num_kv_heads * head_dim
+                    elif mapped_shard_id == "v":
+                        offset = (num_heads + num_kv_heads) * head_dim
+                        size = num_kv_heads * head_dim
+                    weight = param.data.narrow(0, offset, size)
+                elif mapped_shard_id in [0, 1]:
+                    intermediate_size = self.config.intermediate_size
+                    slice_size = intermediate_size // tp_size
+                    if mapped_shard_id == 0:  # gate_proj
+                        offset = 0
+                        size = slice_size
+                    elif mapped_shard_id == 1:  # up_proj
+                        offset = slice_size
+                        size = slice_size
+
+                    weight = param.data.narrow(0, offset, size)
+                else:
+                    weight = param.data
+            else:
+                weight = param.data
+            if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
+                gathered_weights = [torch.zeros_like(weight) for _ in range(tp_size)]
+                torch.distributed.all_gather(gathered_weights, weight)
+                weight = torch.cat(gathered_weights, dim=1)
+            return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
+
+        except Exception:
+            logger.error(
+                f"Error getting weights by name {name} in LlamaForCausalLM: {get_exception_traceback()}"
+            )
+            return None
 
 
 class Phi3ForCausalLM(LlamaForCausalLM):
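The new get_weights_by_name in llama.py above recovers the q/k/v and gate/up slices from the fused qkv_proj and gate_up_proj parameters by row-offset arithmetic. A small worked sketch of that arithmetic with made-up Llama-like dimensions (32 attention heads, 8 KV heads, hidden size 4096, no tensor parallelism):

    # Illustrative numbers only; the real values come from the model config.
    num_attention_heads = 32
    num_key_value_heads = 8
    hidden_size = 4096
    tp_size = 1

    num_heads = num_attention_heads // tp_size
    num_kv_heads = num_key_value_heads // tp_size
    head_dim = hidden_size // num_attention_heads                 # 128

    # Row layout of the fused qkv_proj weight: [q | k | v]
    q_offset, q_size = 0, num_heads * head_dim                    # 0, 4096
    k_offset, k_size = q_size, num_kv_heads * head_dim            # 4096, 1024
    v_offset, v_size = q_size + k_size, num_kv_heads * head_dim   # 5120, 1024

    print(q_offset, q_size, k_offset, k_size, v_offset, v_size)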
sglang/srt/models/llama_classification.py CHANGED
@@ -17,11 +17,11 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
 
 
@@ -30,7 +30,6 @@ class LlamaForClassification(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
sglang/srt/models/llama_embedding.py CHANGED
@@ -3,10 +3,10 @@ from typing import Iterable, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.model_executor.model_runner import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaModel
 
 
@@ -15,7 +15,6 @@ class LlamaEmbeddingModel(nn.Module):
         self,
         config: LlamaConfig,
         quant_config=None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.model = LlamaModel(config, quant_config=quant_config)
sglang/srt/models/llama_reward.py CHANGED
@@ -21,6 +21,7 @@ from transformers import LlamaConfig
 from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
 
 
@@ -29,7 +30,6 @@ class LlamaForSequenceClassification(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -84,9 +84,8 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
-        super().__init__(config, quant_config, cache_config)
+        super().__init__(config, quant_config)
         self.weights = self.Weights(config.hidden_size, self.num_labels)
 
     @torch.no_grad()
sglang/srt/models/llava.py CHANGED
@@ -29,7 +29,6 @@ from transformers import (
     SiglipVisionModel,
 )
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.schedule_batch import ImageInputs
@@ -39,6 +38,7 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
@@ -57,7 +57,7 @@ class LlavaBaseForCausalLM(nn.Module):
         else:
             image_aspect_ratio = "anyres"
         offset_list = []
-        for image_s in image_sizes:
+        for image_idx, image_s in enumerate(image_sizes):
             if len(image_sizes) > 16:
                 # 2x2 pooling with stride 2
                 new_image_feature_len = (
@@ -92,10 +92,6 @@ class LlavaBaseForCausalLM(nn.Module):
                 new_w = int(new_w // times)
                 new_image_feature_len += new_h * (new_w + 1)
 
-            pad_ids = pad_values * (
-                (new_image_feature_len + len(pad_values)) // len(pad_values)
-            )
-            # print("calculated new_image_feature_len: ", new_image_feature_len)
            try:
                offset = input_ids.index(self.config.image_token_index)
            except ValueError:
@@ -103,7 +99,7 @@ class LlavaBaseForCausalLM(nn.Module):
            # old_len + pad_len - 1, because we need to remove image_token_id
            input_ids = (
                input_ids[:offset]
-                + pad_ids[:new_image_feature_len]
+                + [pad_values[image_idx]] * new_image_feature_len
                + input_ids[offset + 1 :]
            )
            offset_list.append(offset)
@@ -138,7 +134,6 @@ class LlavaBaseForCausalLM(nn.Module):
        image_inputs = forward_batch.image_inputs
 
        if forward_batch.forward_mode.is_extend():
-            bs = forward_batch.batch_size
            # Got List[List[str]] extend it to List[str]
            # The length of the List should be equal to batch size
            modalities_list = []
@@ -146,11 +141,16 @@ class LlavaBaseForCausalLM(nn.Module):
            for im in image_inputs:
                if im and im.modalities is not None:
                    modalities_list.extend(im.modalities)
-                if im and im.image_offsets is not None:
+                if im and im.image_offsets:
                    max_image_offset.append(max(im.image_offsets))
                else:
                    max_image_offset.append(-1)
 
+            # Clamp input ids. This is because the input_ids for the image tokens are
+            # filled with the hash values of the image for the prefix matching in the radix attention.
+            # There values are useless because their embeddings will be replaced by vision embeddings anyway.
+            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
+
            # Embed text inputs
            input_embeds = self.language_model.model.embed_tokens(input_ids)
 
@@ -158,6 +158,7 @@ class LlavaBaseForCausalLM(nn.Module):
            need_vision = start_positions <= np.array(max_image_offset)
 
            if need_vision.any():
+                bs = forward_batch.batch_size
                pixel_values = [
                    image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
                ]
@@ -450,7 +451,6 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
    ) -> None:
        super().__init__()
 
@@ -472,7 +472,6 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
    ) -> None:
        super().__init__()
 
@@ -505,7 +504,6 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
    ) -> None:
        super().__init__()
 
sglang/srt/models/llavavid.py CHANGED
@@ -20,11 +20,11 @@ import torch
 from torch import nn
 from transformers import CLIPVisionModel, LlavaConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 
 
@@ -33,7 +33,6 @@ class LlavaVidForCausalLM(nn.Module):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
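The llava.py hunks above clamp input_ids before the text embedding lookup because the image-token positions are filled with per-image hash values (used for radix-cache prefix matching) that fall outside the vocabulary. A minimal sketch of that clamp on a toy tensor, with an assumed vocabulary size:

    import torch

    vocab_size = 32000                                            # assumed for illustration
    input_ids = torch.tensor([1, 15, 987654321, 987654321, 42])   # hash-filled image slots

    # Out-of-range placeholder ids would break the embedding lookup; their exact
    # values do not matter because those embeddings are later overwritten by
    # vision features.
    input_ids.clamp_(min=0, max=vocab_size - 1)
    print(input_ids)  # tensor([    1,    15, 31999, 31999,    42])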