sglang 0.3.4.post2__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. sglang/api.py +1 -1
  2. sglang/bench_latency.py +3 -3
  3. sglang/bench_server_latency.py +2 -3
  4. sglang/bench_serving.py +92 -0
  5. sglang/global_config.py +9 -3
  6. sglang/lang/chat_template.py +50 -25
  7. sglang/lang/interpreter.py +9 -1
  8. sglang/lang/ir.py +11 -2
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/configs/model_config.py +51 -13
  11. sglang/srt/constrained/__init__.py +18 -0
  12. sglang/srt/constrained/bnf_cache.py +61 -0
  13. sglang/srt/constrained/grammar.py +190 -0
  14. sglang/srt/hf_transformers_utils.py +6 -5
  15. sglang/srt/layers/attention/triton_ops/decode_attention.py +110 -30
  16. sglang/srt/layers/attention/triton_ops/prefill_attention.py +1 -1
  17. sglang/srt/layers/fused_moe/fused_moe.py +4 -3
  18. sglang/srt/layers/fused_moe/layer.py +28 -0
  19. sglang/srt/layers/quantization/base_config.py +16 -1
  20. sglang/srt/layers/vocab_parallel_embedding.py +486 -0
  21. sglang/srt/managers/data_parallel_controller.py +7 -6
  22. sglang/srt/managers/detokenizer_manager.py +9 -11
  23. sglang/srt/managers/image_processor.py +4 -3
  24. sglang/srt/managers/io_struct.py +70 -78
  25. sglang/srt/managers/schedule_batch.py +33 -49
  26. sglang/srt/managers/schedule_policy.py +24 -13
  27. sglang/srt/managers/scheduler.py +137 -80
  28. sglang/srt/managers/tokenizer_manager.py +224 -336
  29. sglang/srt/managers/tp_worker.py +5 -5
  30. sglang/srt/mem_cache/flush_cache.py +1 -1
  31. sglang/srt/model_executor/cuda_graph_runner.py +7 -4
  32. sglang/srt/model_executor/model_runner.py +8 -17
  33. sglang/srt/models/baichuan.py +4 -4
  34. sglang/srt/models/chatglm.py +4 -4
  35. sglang/srt/models/commandr.py +1 -1
  36. sglang/srt/models/dbrx.py +5 -5
  37. sglang/srt/models/deepseek.py +4 -4
  38. sglang/srt/models/deepseek_v2.py +4 -4
  39. sglang/srt/models/exaone.py +4 -4
  40. sglang/srt/models/gemma.py +1 -1
  41. sglang/srt/models/gemma2.py +1 -1
  42. sglang/srt/models/gpt2.py +287 -0
  43. sglang/srt/models/gpt_bigcode.py +1 -1
  44. sglang/srt/models/grok.py +4 -4
  45. sglang/srt/models/internlm2.py +4 -4
  46. sglang/srt/models/llama.py +15 -7
  47. sglang/srt/models/llama_embedding.py +2 -10
  48. sglang/srt/models/llama_reward.py +5 -0
  49. sglang/srt/models/minicpm.py +4 -4
  50. sglang/srt/models/minicpm3.py +4 -4
  51. sglang/srt/models/mixtral.py +7 -5
  52. sglang/srt/models/mixtral_quant.py +4 -4
  53. sglang/srt/models/mllama.py +5 -5
  54. sglang/srt/models/olmo.py +4 -4
  55. sglang/srt/models/olmoe.py +4 -4
  56. sglang/srt/models/qwen.py +4 -4
  57. sglang/srt/models/qwen2.py +4 -4
  58. sglang/srt/models/qwen2_moe.py +4 -4
  59. sglang/srt/models/qwen2_vl.py +4 -8
  60. sglang/srt/models/stablelm.py +4 -4
  61. sglang/srt/models/torch_native_llama.py +4 -4
  62. sglang/srt/models/xverse.py +4 -4
  63. sglang/srt/models/xverse_moe.py +4 -4
  64. sglang/srt/openai_api/adapter.py +52 -66
  65. sglang/srt/sampling/sampling_batch_info.py +7 -13
  66. sglang/srt/server.py +31 -35
  67. sglang/srt/server_args.py +34 -5
  68. sglang/srt/utils.py +40 -56
  69. sglang/test/runners.py +2 -1
  70. sglang/test/test_utils.py +73 -25
  71. sglang/utils.py +62 -1
  72. sglang/version.py +1 -1
  73. sglang-0.3.5.dist-info/METADATA +344 -0
  74. {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/RECORD +77 -73
  75. {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/WHEEL +1 -1
  76. sglang-0.3.4.post2.dist-info/METADATA +0 -899
  77. {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/LICENSE +0 -0
  78. {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker.py CHANGED
@@ -15,7 +15,6 @@ limitations under the License.
 
 """A tensor parallel worker."""
 
-import json
 import logging
 from typing import Optional
 
@@ -26,7 +25,7 @@ from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_a
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import broadcast_pyobj, is_multimodal_model, set_random_seed
+from sglang.srt.utils import broadcast_pyobj, set_random_seed
 
 logger = logging.getLogger(__name__)
 
@@ -48,9 +47,10 @@ class TpModelWorker:
         # Init model and tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,
-            server_args.trust_remote_code,
+            trust_remote_code=server_args.trust_remote_code,
             context_length=server_args.context_length,
-            model_override_args=json.loads(server_args.json_model_override_args),
+            model_override_args=server_args.json_model_override_args,
+            is_embedding=server_args.is_embedding,
         )
         self.model_runner = ModelRunner(
             model_config=self.model_config,
@@ -64,7 +64,7 @@ class TpModelWorker:
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
         else:
-            if is_multimodal_model(self.model_config.hf_config.architectures):
+            if self.model_config.is_multimodal:
                 self.processor = get_processor(
                     server_args.tokenizer_path,
                     tokenizer_mode=server_args.tokenizer_mode,
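Note that the json.loads call has moved out of the worker: TpModelWorker now passes server_args.json_model_override_args straight through, together with the new is_embedding flag, so the parsing presumably happens inside the expanded ModelConfig (entry 10 in the file list). A minimal sketch of what such a constructor could look like; the attribute handling here is assumed, not taken from the diff:

    import json
    from typing import Optional, Union

    class ModelConfigSketch:
        """Hypothetical stand-in for sglang.srt.configs.model_config.ModelConfig."""

        def __init__(
            self,
            model_path: str,
            trust_remote_code: bool = True,
            context_length: Optional[int] = None,
            model_override_args: Union[str, dict, None] = None,
            is_embedding: Optional[bool] = None,
        ):
            self.model_path = model_path
            self.trust_remote_code = trust_remote_code
            self.context_length = context_length
            self.is_embedding = is_embedding
            # Accept either the raw JSON string handed over by TpModelWorker or a dict.
            if isinstance(model_override_args, str):
                model_override_args = json.loads(model_override_args)
            self.model_override_args = model_override_args or {}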
sglang/srt/mem_cache/flush_cache.py CHANGED
@@ -29,5 +29,5 @@ if __name__ == "__main__":
     parser.add_argument("--url", type=str, default="http://localhost:30000")
     args = parser.parse_args()
 
-    response = requests.get(args.url + "/flush_cache")
+    response = requests.post(args.url + "/flush_cache")
     assert response.status_code == 200
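The cache-flush helper switches from GET to POST, which implies the /flush_cache route on the server side now expects POST (sglang/srt/server.py is also touched in this release). Hitting the endpoint directly would look like:

    import requests

    # Flush the radix/KV cache of a running sglang server (default port 30000).
    response = requests.post("http://localhost:30000/flush_cache")
    response.raise_for_status()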
sglang/srt/model_executor/cuda_graph_runner.py CHANGED
@@ -113,18 +113,21 @@ class CudaGraphRunner:
         self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
 
         # Batch sizes to capture
-        if self.model_runner.server_args.disable_cuda_graph_padding:
+        if model_runner.server_args.disable_cuda_graph_padding:
             self.capture_bs = list(range(1, 32)) + [64, 128]
         else:
-            self.capture_bs = [1, 2, 3, 4] + [i * 8 for i in range(1, 21)]
+            self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
         self.capture_bs = [
-            bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
+            bs
+            for bs in self.capture_bs
+            if bs <= model_runner.req_to_token_pool.size
+            and bs <= model_runner.server_args.cuda_graph_max_bs
         ]
         self.compile_bs = (
             [
                 bs
                 for bs in self.capture_bs
-                if bs <= self.model_runner.server_args.max_torch_compile_bs
+                if bs <= self.model_runner.server_args.torch_compile_max_bs
             ]
             if self.use_torch_compile
             else []
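Two server arguments are involved here: max_torch_compile_bs is renamed to torch_compile_max_bs, and a new cuda_graph_max_bs caps the captured batch sizes (both live in the updated server_args.py). A standalone sketch of the resulting capture-list computation, with illustrative values standing in for the real pool size and cap:

    # Illustrative values; in sglang they come from ServerArgs and the request-to-token pool.
    req_to_token_pool_size = 160
    cuda_graph_max_bs = 96            # new cap introduced in 0.3.5
    disable_cuda_graph_padding = False

    if disable_cuda_graph_padding:
        capture_bs = list(range(1, 32)) + [64, 128]
    else:
        capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]  # batch size 3 was dropped

    capture_bs = [
        bs
        for bs in capture_bs
        if bs <= req_to_token_pool_size and bs <= cuda_graph_max_bs
    ]
    print(capture_bs)  # [1, 2, 4, 8, 16, 24, ..., 96]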
sglang/srt/model_executor/model_runner.py CHANGED
@@ -59,11 +59,6 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
-    is_attention_free_model,
-    is_embedding_model,
-    is_generation_model,
-    is_multimodal_model,
-    model_has_inner_state,
     monkey_patch_vllm_dummy_weight_loader,
     monkey_patch_vllm_p2p_access_check,
 )
@@ -93,9 +88,8 @@ class ModelRunner:
         self.tp_size = tp_size
         self.dist_port = nccl_port
         self.server_args = server_args
-        self.is_multimodal_model = is_multimodal_model(
-            self.model_config.hf_config.architectures
-        )
+        self.is_generation = model_config.is_generation
+        self.is_multimodal = model_config.is_multimodal
 
         # Model-specific adjustment
         if (
@@ -119,12 +113,12 @@ class ModelRunner:
             self.server_args.ds_heavy_channel_type
         )
 
-        if self.is_multimodal_model:
+        if self.is_multimodal:
             logger.warning(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
             )
             server_args.chunked_prefill_size = None
-            server_args.mem_fraction_static *= 0.95
+            self.mem_fraction_static *= 0.95
             # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
             if self.model_config.hf_config.architectures == [
                 "Qwen2VLForConditionalGeneration"
@@ -270,9 +264,6 @@ class ModelRunner:
             if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
-        self.is_generation = is_generation_model(
-            self.model_config.hf_config.architectures, self.server_args.is_embedding
-        )
 
         logger.info(
             f"Load weight end. "
@@ -679,7 +670,7 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
 
 # Monkey patch model loader
 setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
-setattr(ModelRegistry, "is_multimodal_model", is_multimodal_model)
-setattr(ModelRegistry, "is_attention_free_model", is_attention_free_model)
-setattr(ModelRegistry, "model_has_inner_state", model_has_inner_state)
-setattr(ModelRegistry, "is_embedding_model", is_embedding_model)
+setattr(ModelRegistry, "is_multimodal_model", lambda model_architectures: False)
+setattr(ModelRegistry, "is_attention_free_model", lambda model_architectures: False)
+setattr(ModelRegistry, "model_has_inner_state", lambda model_architectures: False)
+setattr(ModelRegistry, "is_embedding_model", lambda model_architectures: False)
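The architecture-sniffing helpers (is_generation_model, is_multimodal_model, is_embedding_model, and friends) disappear from sglang.srt.utils; ModelRunner now reads is_generation and is_multimodal directly from the expanded ModelConfig (entry 10 in the file list), and the vLLM ModelRegistry hooks are stubbed out with lambdas. A hypothetical sketch of how such flags might be derived from hf_config.architectures; the real logic in model_config.py may differ:

    from typing import List, Optional

    # Lookup tables assumed for illustration only.
    MULTIMODAL_ARCHS = {"Qwen2VLForConditionalGeneration", "MllamaForConditionalGeneration"}
    EMBEDDING_ARCHS = {"LlamaEmbeddingModel", "MistralModel"}

    def derive_model_flags(
        architectures: List[str], is_embedding_override: Optional[bool] = None
    ):
        """Return (is_generation, is_multimodal) the way ModelConfig might compute them."""
        is_multimodal = any(arch in MULTIMODAL_ARCHS for arch in architectures)
        if is_embedding_override is not None:
            is_generation = not is_embedding_override
        else:
            is_generation = not any(arch in EMBEDDING_ARCHS for arch in architectures)
        return is_generation, is_multimodal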
sglang/srt/models/baichuan.py CHANGED
@@ -34,10 +34,6 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
@@ -45,6 +41,10 @@ from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
sglang/srt/models/chatglm.py CHANGED
@@ -24,10 +24,6 @@ from torch import nn
 from torch.nn import LayerNorm
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs import ChatGLMConfig
 
@@ -41,6 +37,10 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 LoraConfig = None
sglang/srt/models/commandr.py CHANGED
@@ -50,7 +50,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
@@ -62,6 +61,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import set_weight_attrs
 
sglang/srt/models/dbrx.py CHANGED
@@ -27,11 +27,6 @@ from vllm.distributed import (
 )
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    DEFAULT_VOCAB_PADDING_SIZE,
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
@@ -43,6 +38,11 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE,
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import set_weight_attrs
 
sglang/srt/models/deepseek.py CHANGED
@@ -28,10 +28,6 @@ from vllm.distributed import (
 )
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
@@ -45,6 +41,10 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
sglang/srt/models/deepseek_v2.py CHANGED
@@ -27,10 +27,6 @@ from vllm.distributed import (
 )
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
@@ -44,6 +40,10 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import is_flashinfer_available
sglang/srt/models/exaone.py CHANGED
@@ -23,10 +23,6 @@ import torch
 from torch import nn
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
@@ -39,6 +35,10 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
sglang/srt/models/gemma.py CHANGED
@@ -24,7 +24,6 @@ from transformers import PretrainedConfig
 from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import GeluAndMul
@@ -37,6 +36,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
sglang/srt/models/gemma2.py CHANGED
@@ -24,7 +24,6 @@ from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 
 # from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import GeluAndMul
@@ -37,6 +36,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
sglang/srt/models/gpt2.py ADDED
@@ -0,0 +1,287 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
+# Copyright 2023 The vLLM team.
+# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only GPT-2 model compatible with HuggingFace weights."""
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import GPT2Config
+from vllm.config import CacheConfig
+from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+#from sglang.srt.layers.activation import get_act_fn
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+
+class GPT2Attention(nn.Module):
+
+    def __init__(
+        self,
+        layer_id: int,
+        config: GPT2Config,
+        cache_config = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        total_num_heads = config.num_attention_heads
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
+        assert total_num_heads % tensor_model_parallel_world_size == 0
+        self.num_heads = total_num_heads // tensor_model_parallel_world_size
+        self.head_dim = self.hidden_size // total_num_heads
+        self.scale = self.head_dim**-0.5
+
+        self.c_attn = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            total_num_heads,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.c_attn",
+        )
+        self.c_proj = RowParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
+        )
+        self.attn = RadixAttention(self.num_heads,
+                                   self.head_dim,
+                                   scaling=self.scale,
+                                   num_kv_heads=total_num_heads,
+                                   layer_id=layer_id)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.c_attn(hidden_states)
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        attn_output = self.attn(q, k, v, forward_batch)
+        attn_output, _ = self.c_proj(attn_output)
+        return attn_output
+
+
+class GPT2MLP(nn.Module):
+
+    def __init__(
+        self,
+        intermediate_size: int,
+        config: GPT2Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+        self.c_fc = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.c_fc",
+        )
+        self.c_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
+        )
+        self.act = get_act_fn(config.activation_function, quant_config,
+                              intermediate_size)
+
+    def forward(self, hidden_states: torch.Tensor,) -> torch.Tensor:
+        hidden_states, _ = self.c_fc(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states, _ = self.c_proj(hidden_states)
+        return hidden_states
+
+
+class GPT2Block(nn.Module):
+
+    def __init__(
+        self,
+        layer_id: int,
+        config: GPT2Config,
+        cache_config = None,
+
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+        inner_dim = (config.n_inner if config.n_inner is not None else 4 *
+                     hidden_size)
+
+        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.attn = GPT2Attention(layer_id,
+                                  config,
+                                  cache_config,
+                                  quant_config,
+                                  prefix=f"{prefix}.attn")
+        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.mlp = GPT2MLP(inner_dim,
+                           config,
+                           quant_config,
+                           prefix=f"{prefix}.mlp")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.ln_1(hidden_states)
+        attn_output = self.attn(
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+        # residual connection
+        hidden_states = attn_output + residual
+
+        residual = hidden_states
+        hidden_states = self.ln_2(hidden_states)
+        feed_forward_hidden_states = self.mlp(hidden_states)
+        # residual connection
+        hidden_states = residual + feed_forward_hidden_states
+        return hidden_states
+
+
+
+class GPT2Model(nn.Module):
+
+    def __init__(
+        self,
+        config: GPT2Config,
+        cache_config = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+        assert not config.add_cross_attention
+        assert not config.scale_attn_by_inverse_layer_idx
+        assert not config.reorder_and_upcast_attn
+        self.embed_dim = config.hidden_size
+        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
+        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
+        self.h = nn.ModuleList(
+            [
+                GPT2Block(i, config, cache_config, quant_config)
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+
+        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        inputs_embeds = self.wte(input_ids)
+        position_embeds = self.wpe(position_ids)
+        hidden_states = inputs_embeds + position_embeds
+
+        for i in range(len(self.h)):
+            layer = self.h[i]
+            hidden_states = layer(hidden_states, forward_batch)
+
+        hidden_states = self.ln_f(hidden_states)
+        return hidden_states
+
+
+class GPT2LMHeadModel(nn.Module):
+
+    def __init__(
+        self,
+        config: GPT2Config,
+        cache_config = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.transformer = GPT2Model(config,
+                                     cache_config,
+                                     quant_config,
+                                     prefix="transformer")
+        self.lm_head = self.transformer.wte
+
+        self.logits_processor = LogitsProcessor(config)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(input_ids, positions, forward_batch)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
+        )
+
+
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in weights:
+            if "lm_head.weight" in name:
+                # GPT-2 ties the weights of the embedding layer and the final
+                # linear layer.
+                continue
+            if ".attn.bias" in name or ".attn.masked_bias" in name:
+                # Skip attention mask.
+                # NOTE: "c_attn.bias" should not be skipped.
+                continue
+            if not name.startswith("transformer."):
+                name = "transformer." + name
+
+            param = params_dict[name]
+            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
+            # Because of this, we need to transpose the weights.
+            # Note(zhuohan): the logic below might break quantized models.
+            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
+                if conv1d_weight_name not in name:
+                    continue
+                if not name.endswith(".weight"):
+                    continue
+                loaded_weight = loaded_weight.t()
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+EntryClass = GPT2LMHeadModel
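With EntryClass = GPT2LMHeadModel, GPT-2 checkpoints become servable like any other supported architecture. A quick way to try it from Python, assuming the usual sgl.Runtime front end and the Hugging Face "gpt2" checkpoint (the snippet itself is not part of this diff):

    import sglang as sgl

    # Start a local sglang runtime backed by the new GPT-2 implementation.
    runtime = sgl.Runtime(model_path="gpt2")
    sgl.set_default_backend(runtime)

    @sgl.function
    def complete(s, prompt):
        s += prompt
        s += sgl.gen("completion", max_tokens=32)

    state = complete.run(prompt="The quick brown fox")
    print(state["completion"])

    runtime.shutdown()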
sglang/srt/models/gpt_bigcode.py CHANGED
@@ -23,7 +23,6 @@ from torch import nn
 from transformers import GPTBigCodeConfig
 from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import get_act_fn
@@ -35,6 +34,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
sglang/srt/models/grok.py CHANGED
@@ -28,10 +28,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
@@ -45,6 +41,10 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
sglang/srt/models/internlm2.py CHANGED
@@ -23,10 +23,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
@@ -39,6 +35,10 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 