sglang 0.3.6.post3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/bench_one_batch.py +4 -0
  2. sglang/bench_serving.py +13 -0
  3. sglang/check_env.py +1 -1
  4. sglang/srt/_custom_ops.py +118 -0
  5. sglang/srt/configs/device_config.py +17 -0
  6. sglang/srt/configs/load_config.py +84 -0
  7. sglang/srt/configs/model_config.py +161 -4
  8. sglang/srt/configs/qwen2vl.py +5 -8
  9. sglang/srt/constrained/outlines_backend.py +6 -1
  10. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  11. sglang/srt/distributed/__init__.py +3 -0
  12. sglang/srt/distributed/communication_op.py +34 -0
  13. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  14. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  15. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  16. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  17. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  18. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  21. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  22. sglang/srt/distributed/parallel_state.py +1275 -0
  23. sglang/srt/distributed/utils.py +223 -0
  24. sglang/srt/hf_transformers_utils.py +37 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +13 -15
  26. sglang/srt/layers/attention/torch_native_backend.py +285 -0
  27. sglang/srt/layers/fused_moe_patch.py +20 -11
  28. sglang/srt/layers/linear.py +1 -0
  29. sglang/srt/layers/logits_processor.py +17 -3
  30. sglang/srt/layers/quantization/__init__.py +34 -0
  31. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  32. sglang/srt/lora/lora.py +1 -1
  33. sglang/srt/managers/io_struct.py +48 -2
  34. sglang/srt/managers/schedule_batch.py +18 -14
  35. sglang/srt/managers/schedule_policy.py +7 -4
  36. sglang/srt/managers/scheduler.py +76 -20
  37. sglang/srt/managers/tokenizer_manager.py +166 -68
  38. sglang/srt/managers/tp_worker.py +36 -3
  39. sglang/srt/managers/tp_worker_overlap_thread.py +21 -3
  40. sglang/srt/model_executor/cuda_graph_runner.py +16 -7
  41. sglang/srt/model_executor/forward_batch_info.py +9 -4
  42. sglang/srt/model_executor/model_runner.py +136 -150
  43. sglang/srt/model_loader/__init__.py +34 -0
  44. sglang/srt/model_loader/loader.py +1139 -0
  45. sglang/srt/model_loader/utils.py +41 -0
  46. sglang/srt/model_loader/weight_utils.py +640 -0
  47. sglang/srt/models/baichuan.py +9 -10
  48. sglang/srt/models/chatglm.py +6 -15
  49. sglang/srt/models/commandr.py +2 -3
  50. sglang/srt/models/dbrx.py +2 -3
  51. sglang/srt/models/deepseek.py +4 -11
  52. sglang/srt/models/deepseek_v2.py +3 -11
  53. sglang/srt/models/exaone.py +2 -3
  54. sglang/srt/models/gemma.py +2 -6
  55. sglang/srt/models/gemma2.py +3 -14
  56. sglang/srt/models/gemma2_reward.py +0 -1
  57. sglang/srt/models/gpt2.py +5 -12
  58. sglang/srt/models/gpt_bigcode.py +6 -22
  59. sglang/srt/models/grok.py +3 -3
  60. sglang/srt/models/internlm2.py +2 -3
  61. sglang/srt/models/internlm2_reward.py +0 -1
  62. sglang/srt/models/llama.py +97 -27
  63. sglang/srt/models/llama_classification.py +1 -2
  64. sglang/srt/models/llama_embedding.py +1 -2
  65. sglang/srt/models/llama_reward.py +2 -3
  66. sglang/srt/models/llava.py +1 -4
  67. sglang/srt/models/llavavid.py +1 -2
  68. sglang/srt/models/minicpm.py +4 -7
  69. sglang/srt/models/minicpm3.py +6 -19
  70. sglang/srt/models/mixtral.py +12 -5
  71. sglang/srt/models/mixtral_quant.py +2 -3
  72. sglang/srt/models/mllama.py +3 -7
  73. sglang/srt/models/olmo.py +2 -8
  74. sglang/srt/models/olmo2.py +0 -1
  75. sglang/srt/models/olmoe.py +3 -5
  76. sglang/srt/models/phi3_small.py +8 -8
  77. sglang/srt/models/qwen.py +2 -3
  78. sglang/srt/models/qwen2.py +10 -9
  79. sglang/srt/models/qwen2_moe.py +4 -11
  80. sglang/srt/models/qwen2_vl.py +2 -6
  81. sglang/srt/models/registry.py +99 -0
  82. sglang/srt/models/stablelm.py +2 -3
  83. sglang/srt/models/torch_native_llama.py +6 -12
  84. sglang/srt/models/xverse.py +2 -4
  85. sglang/srt/models/xverse_moe.py +4 -11
  86. sglang/srt/models/yivl.py +2 -3
  87. sglang/srt/openai_api/adapter.py +9 -5
  88. sglang/srt/openai_api/protocol.py +1 -0
  89. sglang/srt/server.py +267 -170
  90. sglang/srt/server_args.py +65 -31
  91. sglang/srt/utils.py +245 -28
  92. sglang/test/test_utils.py +7 -0
  93. sglang/version.py +1 -1
  94. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/METADATA +1 -1
  95. sglang-0.4.0.dist-info/RECORD +184 -0
  96. sglang-0.3.6.post3.dist-info/RECORD +0 -162
  97. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
  98. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
  99. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
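Most of the per-file hunks below repeat the same three changes: model constructors drop the vllm-style cache_config (and, where present, lora_config) argument, LogitsProcessor is called with the lm_head module instead of its .weight tensor, and default_weight_loader is imported from the new sglang.srt.model_loader.weight_utils module rather than from vllm. A minimal sketch of the 0.4.0-style model class, assuming the sglang import paths shown in the hunks; ExampleForCausalLM and ExampleModel are illustrative names, not files in this diff:

from torch import nn

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class ExampleForCausalLM(nn.Module):
    # 0.3.6.post3 signatures also accepted cache_config=None (and sometimes lora_config).
    def __init__(self, config, quant_config=None):
        super().__init__()
        self.model = ExampleModel(config, quant_config)  # hypothetical backbone, for illustration only
        self.lm_head = ParallelLMHead(
            config.vocab_size, config.hidden_size, quant_config=quant_config
        )
        self.logits_processor = LogitsProcessor(config)

    def forward(self, input_ids, positions, forward_batch: ForwardBatch):
        hidden_states = self.model(input_ids, positions, forward_batch)
        # 0.3.6.post3 passed self.lm_head.weight here; 0.4.0 passes the module itself.
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )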
sglang/srt/models/baichuan.py CHANGED
@@ -34,7 +34,6 @@ from vllm.model_executor.layers.linear import (
  RowParallelLinear,
  )
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
@@ -46,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  VocabParallelEmbedding,
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader


  def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
@@ -329,7 +329,6 @@ class BaiChuanBaseForCausalLM(nn.Module):
  self,
  config: PretrainedConfig,
  position_embedding: str,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ):
  super().__init__()
@@ -338,11 +337,12 @@ class BaiChuanBaseForCausalLM(nn.Module):

  self.quant_config = quant_config
  self.model = BaiChuanModel(config, position_embedding, quant_config)
- self.lm_head = ParallelLMHead(
- config.vocab_size, config.hidden_size, quant_config=quant_config
- )
  if self.config.tie_word_embeddings:
- self.lm_head.weight = self.model.embed_tokens.weight
+ self.lm_head = self.model.embed_tokens
+ else:
+ self.lm_head = ParallelLMHead(
+ config.vocab_size, config.hidden_size, quant_config=quant_config
+ )
  self.logits_processor = LogitsProcessor(config)

  def forward(
@@ -353,7 +353,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
  ) -> torch.Tensor:
  hidden_states = self.model(input_ids, positions, forward_batch)
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, forward_batch
+ input_ids, hidden_states, self.lm_head, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -403,13 +403,12 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
  def __init__(
  self,
  config,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ):
  if config.hidden_size == 4096: # baichuan2 7b
- super().__init__(config, "ROPE", cache_config, quant_config)
+ super().__init__(config, "ROPE", quant_config)
  else: # baichuan 13b, baichuan2 13b
- super().__init__(config, "ALIBI", cache_config, quant_config)
+ super().__init__(config, "ALIBI", quant_config)


  EntryClass = [BaichuanForCausalLM]
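A side effect of passing the head module (rather than a bare weight tensor) into the logits processor is that tied embeddings become trivial: as the BaiChuan hunk above shows, lm_head can simply be the embedding module when config.tie_word_embeddings is set. A toy illustration of why the module-based interface accommodates that, not sglang's actual LogitsProcessor implementation:

import torch
from torch import nn


def compute_logits(hidden_states: torch.Tensor, lm_head: nn.Module) -> torch.Tensor:
    # Any module exposing a `weight` parameter works here, whether it is a
    # dedicated LM head or the reused embedding table.
    return torch.matmul(hidden_states, lm_head.weight.t())


# Usage sketch: a tied head is just the embedding module itself.
embed_tokens = nn.Embedding(32000, 4096)
logits = compute_logits(torch.randn(2, 4096), embed_tokens)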
sglang/srt/models/chatglm.py CHANGED
@@ -23,7 +23,6 @@ from torch import nn
  from torch.nn import LayerNorm
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader
  from vllm.transformers_utils.configs import ChatGLMConfig

  from sglang.srt.layers.activation import SiluAndMul
@@ -41,6 +40,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  VocabParallelEmbedding,
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader

  LoraConfig = None

@@ -50,7 +50,6 @@ class GLMAttention(nn.Module):
  self,
  config,
  layer_id: int = 0,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ):
  super().__init__()
@@ -186,7 +185,6 @@ class GLMBlock(nn.Module):
  self,
  config,
  layer_id: int,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ):
  super().__init__()
@@ -203,7 +201,7 @@ class GLMBlock(nn.Module):
  )

  # Self attention.
- self.self_attention = GLMAttention(config, layer_id, cache_config, quant_config)
+ self.self_attention = GLMAttention(config, layer_id, quant_config)
  self.hidden_dropout = config.hidden_dropout

  # Layernorm on the attention output
@@ -258,7 +256,6 @@ class GLMTransformer(nn.Module):
  def __init__(
  self,
  config,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ):
  super().__init__()
@@ -269,10 +266,7 @@ class GLMTransformer(nn.Module):

  # Transformer layers.
  self.layers = nn.ModuleList(
- [
- GLMBlock(config, i, cache_config, quant_config)
- for i in range(self.num_layers)
- ]
+ [GLMBlock(config, i, quant_config) for i in range(self.num_layers)]
  )

  if self.post_layer_norm:
@@ -306,7 +300,6 @@ class ChatGLMM(nn.Module):
  def __init__(
  self,
  config,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ):
  super().__init__()
@@ -318,7 +311,7 @@ class ChatGLMM(nn.Module):
  self.num_layers = config.num_layers
  self.multi_query_group_num = config.multi_query_group_num
  self.kv_channels = config.kv_channels
- self.encoder = GLMTransformer(config, cache_config, quant_config)
+ self.encoder = GLMTransformer(config, quant_config)

  self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size)

@@ -357,15 +350,13 @@ class ChatGLMForCausalLM(nn.Module):
  def __init__(
  self,
  config: ChatGLMConfig,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
- lora_config: Optional[LoraConfig] = None,
  ):
  super().__init__()
  self.config: ChatGLMConfig = config
  self.quant_config = quant_config
  self.max_position_embeddings = getattr(config, "max_sequence_length", 8192)
- self.transformer = ChatGLMM(config, cache_config, quant_config)
+ self.transformer = ChatGLMM(config, quant_config)
  self.lm_head = self.transformer.output_layer
  self.logits_processor = LogitsProcessor(config)

@@ -378,7 +369,7 @@ class ChatGLMForCausalLM(nn.Module):
  ) -> torch.Tensor:
  hidden_states = self.transformer(input_ids, positions, forward_batch)
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, forward_batch
+ input_ids, hidden_states, self.lm_head, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/commandr.py CHANGED
@@ -49,7 +49,6 @@ from vllm.distributed import (
  get_tensor_model_parallel_world_size,
  )
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.linear import (
@@ -62,6 +61,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.utils import set_weight_attrs


@@ -318,7 +318,6 @@ class CohereForCausalLM(nn.Module):
  self,
  config: PretrainedConfig,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config=None,
  ) -> None:
  super().__init__()
  self.config = config
@@ -339,7 +338,7 @@ class CohereForCausalLM(nn.Module):
  forward_batch,
  )
  return self.logits_processor(
- input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
+ input_ids, hidden_states, self.model.embed_tokens, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/dbrx.py CHANGED
@@ -25,7 +25,6 @@ from vllm.distributed import (
  tensor_model_parallel_all_reduce,
  )
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader
  from vllm.transformers_utils.configs.dbrx import DbrxConfig

  from sglang.srt.layers.fused_moe_triton import fused_moe
@@ -43,6 +42,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  VocabParallelEmbedding,
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.utils import set_weight_attrs


@@ -366,7 +366,6 @@ class DbrxForCausalLM(nn.Module):
  self,
  config: DbrxConfig,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config=None,
  ):
  super().__init__()
  self.config = config
@@ -390,7 +389,7 @@ class DbrxForCausalLM(nn.Module):
  ) -> torch.Tensor:
  hidden_states = self.transformer(input_ids, positions, forward_batch)
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, forward_batch
+ input_ids, hidden_states, self.lm_head, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/deepseek.py CHANGED
@@ -27,7 +27,6 @@ from vllm.distributed import (
  tensor_model_parallel_all_reduce,
  )
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.fused_moe_triton import fused_moe
@@ -46,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  VocabParallelEmbedding,
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader


  class DeepseekMLP(nn.Module):
@@ -184,7 +184,6 @@ class DeepseekAttention(nn.Module):
  rope_theta: float = 10000,
  rope_scaling: Optional[Dict[str, Any]] = None,
  max_position_embeddings: int = 8192,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ) -> None:
  super().__init__()
@@ -261,7 +260,6 @@ class DeepseekDecoderLayer(nn.Module):
  self,
  config: PretrainedConfig,
  layer_id: int,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ) -> None:
  super().__init__()
@@ -277,7 +275,6 @@ class DeepseekDecoderLayer(nn.Module):
  rope_theta=rope_theta,
  rope_scaling=rope_scaling,
  max_position_embeddings=max_position_embeddings,
- cache_config=cache_config,
  quant_config=quant_config,
  )
  if (
@@ -330,7 +327,6 @@ class DeepseekModel(nn.Module):
  def __init__(
  self,
  config: PretrainedConfig,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ) -> None:
  super().__init__()
@@ -343,9 +339,7 @@ class DeepseekModel(nn.Module):
  )
  self.layers = nn.ModuleList(
  [
- DeepseekDecoderLayer(
- config, layer_id, cache_config, quant_config=quant_config
- )
+ DeepseekDecoderLayer(config, layer_id, quant_config=quant_config)
  for layer_id in range(config.num_hidden_layers)
  ]
  )
@@ -373,13 +367,12 @@ class DeepseekForCausalLM(nn.Module):
  def __init__(
  self,
  config: PretrainedConfig,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ) -> None:
  super().__init__()
  self.config = config
  self.quant_config = quant_config
- self.model = DeepseekModel(config, cache_config, quant_config)
+ self.model = DeepseekModel(config, quant_config)
  self.lm_head = ParallelLMHead(
  config.vocab_size, config.hidden_size, quant_config=quant_config
  )
@@ -394,7 +387,7 @@ class DeepseekForCausalLM(nn.Module):
  ) -> torch.Tensor:
  hidden_states = self.model(input_ids, positions, forward_batch)
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, forward_batch
+ input_ids, hidden_states, self.lm_head, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
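Each of these files also re-points default_weight_loader at the new sglang.srt.model_loader.weight_utils module (file 46 in the list above) instead of vllm's copy; the load_weights bodies are otherwise untouched in these hunks. For orientation, the loader is typically used as the fallback in a loop of roughly this shape (a generic sketch of the common pattern, not the exact body of any file in this diff):

from typing import Iterable, Tuple

import torch

from sglang.srt.model_loader.weight_utils import default_weight_loader


def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    params_dict = dict(self.named_parameters())
    for name, loaded_weight in weights:
        if name not in params_dict:
            continue  # skip checkpoint tensors with no matching parameter
        param = params_dict[name]
        # Parallel layers attach a custom weight_loader to their parameters;
        # plain parameters fall back to a simple copy.
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, loaded_weight)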
sglang/srt/models/deepseek_v2.py CHANGED
@@ -28,7 +28,6 @@ from vllm.distributed import (
  tensor_model_parallel_all_reduce,
  )
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.fused_moe_triton import FusedMoE
@@ -48,6 +47,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  )
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.utils import is_flashinfer_available

  if is_flashinfer_available():
@@ -189,7 +189,6 @@ class DeepseekV2Attention(nn.Module):
  rope_theta: float = 10000,
  rope_scaling: Optional[Dict[str, Any]] = None,
  max_position_embeddings: int = 8192,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  layer_id=None,
  ) -> None:
@@ -337,7 +336,6 @@ class DeepseekV2AttentionMLA(nn.Module):
  rope_theta: float = 10000,
  rope_scaling: Optional[Dict[str, Any]] = None,
  max_position_embeddings: int = 8192,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  layer_id=None,
  use_dp=False,
@@ -568,7 +566,6 @@ class DeepseekV2DecoderLayer(nn.Module):
  self,
  config: PretrainedConfig,
  layer_id: int,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ) -> None:
  super().__init__()
@@ -599,7 +596,6 @@ class DeepseekV2DecoderLayer(nn.Module):
  rope_theta=rope_theta,
  rope_scaling=rope_scaling,
  max_position_embeddings=max_position_embeddings,
- cache_config=cache_config,
  quant_config=quant_config,
  layer_id=layer_id,
  use_dp=self.enable_dp_attention,
@@ -619,7 +615,6 @@ class DeepseekV2DecoderLayer(nn.Module):
  rope_theta=rope_theta,
  rope_scaling=rope_scaling,
  max_position_embeddings=max_position_embeddings,
- cache_config=cache_config,
  quant_config=quant_config,
  layer_id=layer_id,
  )
@@ -685,7 +680,6 @@ class DeepseekV2Model(nn.Module):
  def __init__(
  self,
  config: PretrainedConfig,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ) -> None:
  super().__init__()
@@ -702,7 +696,6 @@ class DeepseekV2Model(nn.Module):
  DeepseekV2DecoderLayer(
  config,
  layer_id,
- cache_config=cache_config,
  quant_config=quant_config,
  )
  for layer_id in range(config.num_hidden_layers)
@@ -733,13 +726,12 @@ class DeepseekV2ForCausalLM(nn.Module):
  def __init__(
  self,
  config: PretrainedConfig,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ) -> None:
  super().__init__()
  self.config = config
  self.quant_config = quant_config
- self.model = DeepseekV2Model(config, cache_config, quant_config)
+ self.model = DeepseekV2Model(config, quant_config)
  if global_server_args_dict["enable_dp_attention"]:
  self.lm_head = ReplicatedLinear(
  config.hidden_size,
@@ -763,7 +755,7 @@ class DeepseekV2ForCausalLM(nn.Module):
  hidden_states = self.model(input_ids, positions, forward_batch)
  if not forward_batch.forward_mode.is_idle():
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, forward_batch
+ input_ids, hidden_states, self.lm_head, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/exaone.py CHANGED
@@ -22,7 +22,6 @@ import torch
  from torch import nn
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
@@ -39,6 +38,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  VocabParallelEmbedding,
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader


  class ExaoneGatedMLP(nn.Module):
@@ -293,7 +293,6 @@ class ExaoneForCausalLM(nn.Module):
  self,
  config,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config=None,
  ) -> None:
  super().__init__()
  self.config = config
@@ -314,7 +313,7 @@ class ExaoneForCausalLM(nn.Module):
  input_ids, positions, forward_batch, input_embeds
  )
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, forward_batch
+ input_ids, hidden_states, self.lm_head, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/gemma.py CHANGED
@@ -21,10 +21,8 @@ from typing import Iterable, Optional, Tuple
  import torch
  from torch import nn
  from transformers import PretrainedConfig
- from vllm.config import LoRAConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import GeluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
@@ -38,6 +36,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader


  class GemmaMLP(nn.Module):
@@ -278,10 +277,7 @@ class GemmaForCausalLM(nn.Module):
  self,
  config: PretrainedConfig,
  quant_config: Optional[QuantizationConfig] = None,
- lora_config: Optional[LoRAConfig] = None,
- cache_config=None,
  ) -> None:
- del lora_config # Unused.
  super().__init__()
  self.config = config
  self.quant_config = quant_config
@@ -298,7 +294,7 @@ class GemmaForCausalLM(nn.Module):
  ) -> torch.Tensor:
  hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
  return self.logits_processor(
- input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
+ input_ids, hidden_states, self.model.embed_tokens, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/gemma2.py CHANGED
@@ -20,12 +20,8 @@ from typing import Iterable, Optional, Set, Tuple, Union
  import torch
  from torch import nn
  from transformers import PretrainedConfig
- from vllm.config import LoRAConfig
  from vllm.distributed import get_tensor_model_parallel_world_size

- # from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-
  from sglang.srt.layers.activation import GeluAndMul
  from sglang.srt.layers.layernorm import GemmaRMSNorm
  from sglang.srt.layers.linear import (
@@ -38,6 +34,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.utils import make_layers


@@ -106,7 +103,6 @@ class Gemma2Attention(nn.Module):
  head_dim: int,
  max_position_embeddings: int,
  rope_theta: float,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ) -> None:
  super().__init__()
@@ -191,7 +187,6 @@ class Gemma2DecoderLayer(nn.Module):
  self,
  layer_id: int,
  config: PretrainedConfig,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ) -> None:
  super().__init__()
@@ -205,7 +200,6 @@ class Gemma2DecoderLayer(nn.Module):
  head_dim=config.head_dim,
  max_position_embeddings=config.max_position_embeddings,
  rope_theta=config.rope_theta,
- cache_config=cache_config,
  quant_config=quant_config,
  )
  self.hidden_size = config.hidden_size
@@ -258,7 +252,6 @@ class Gemma2Model(nn.Module):
  def __init__(
  self,
  config: PretrainedConfig,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ) -> None:
  super().__init__()
@@ -273,7 +266,6 @@ class Gemma2Model(nn.Module):
  lambda idx, prefix: Gemma2DecoderLayer(
  layer_id=idx,
  config=config,
- cache_config=cache_config,
  quant_config=quant_config,
  ),
  prefix="",
@@ -342,15 +334,12 @@ class Gemma2ForCausalLM(nn.Module):
  def __init__(
  self,
  config: PretrainedConfig,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
- lora_config: Optional[LoRAConfig] = None,
  ) -> None:
- del lora_config # Unused.
  super().__init__()
  self.config = config
  self.quant_config = quant_config
- self.model = Gemma2Model(config, cache_config, quant_config)
+ self.model = Gemma2Model(config, quant_config)
  self.logits_processor = LogitsProcessor(config)

  @torch.no_grad()
@@ -363,7 +352,7 @@ class Gemma2ForCausalLM(nn.Module):
  ) -> torch.Tensor:
  hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
  return self.logits_processor(
- input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
+ input_ids, hidden_states, self.model.embed_tokens, forward_batch
  )

  def get_attention_sliding_window_size(self):
sglang/srt/models/gemma2_reward.py CHANGED
@@ -29,7 +29,6 @@ class Gemma2ForSequenceClassification(nn.Module):
  self,
  config: Gemma2Config,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config=None,
  ) -> None:
  super().__init__()
  self.config = config
sglang/srt/models/gpt2.py CHANGED
@@ -22,11 +22,9 @@ from typing import Iterable, List, Optional, Tuple
  import torch
  from torch import nn
  from transformers import GPT2Config
- from vllm.config import CacheConfig
  from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.activation import get_act_fn
  from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  # from sglang.srt.layers.activation import get_act_fn
  from sglang.srt.layers.linear import (
@@ -39,6 +37,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader


  class GPT2Attention(nn.Module):
@@ -47,7 +46,6 @@ class GPT2Attention(nn.Module):
  self,
  layer_id: int,
  config: GPT2Config,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  prefix: str = "",
  ):
@@ -140,7 +138,6 @@ class GPT2Block(nn.Module):
  self,
  layer_id: int,
  config: GPT2Config,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  prefix: str = "",
  ):
@@ -150,7 +147,7 @@ class GPT2Block(nn.Module):

  self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  self.attn = GPT2Attention(
- layer_id, config, cache_config, quant_config, prefix=f"{prefix}.attn"
+ layer_id, config, quant_config, prefix=f"{prefix}.attn"
  )
  self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  self.mlp = GPT2MLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")
@@ -182,7 +179,6 @@ class GPT2Model(nn.Module):
  def __init__(
  self,
  config: GPT2Config,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  prefix: str = "",
  ):
@@ -196,7 +192,7 @@ class GPT2Model(nn.Module):
  self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
  self.h = nn.ModuleList(
  [
- GPT2Block(i, config, cache_config, quant_config)
+ GPT2Block(i, config, quant_config)
  for i in range(config.num_hidden_layers)
  ]
  )
@@ -226,15 +222,12 @@ class GPT2LMHeadModel(nn.Module):
  def __init__(
  self,
  config: GPT2Config,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ):
  super().__init__()
  self.config = config
  self.quant_config = quant_config
- self.transformer = GPT2Model(
- config, cache_config, quant_config, prefix="transformer"
- )
+ self.transformer = GPT2Model(config, quant_config, prefix="transformer")
  self.lm_head = self.transformer.wte

  self.logits_processor = LogitsProcessor(config)
@@ -247,7 +240,7 @@ class GPT2LMHeadModel(nn.Module):
  ) -> torch.Tensor:
  hidden_states = self.transformer(input_ids, positions, forward_batch)
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, forward_batch
+ input_ids, hidden_states, self.lm_head, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):