sglang 0.1.14__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. sglang/__init__.py +55 -2
  2. sglang/api.py +3 -5
  3. sglang/backend/anthropic.py +18 -4
  4. sglang/backend/openai.py +2 -1
  5. sglang/backend/runtime_endpoint.py +18 -5
  6. sglang/backend/vertexai.py +1 -0
  7. sglang/global_config.py +1 -0
  8. sglang/lang/chat_template.py +74 -0
  9. sglang/lang/interpreter.py +40 -16
  10. sglang/lang/tracer.py +6 -4
  11. sglang/launch_server.py +2 -1
  12. sglang/srt/constrained/fsm_cache.py +1 -0
  13. sglang/srt/constrained/jump_forward.py +1 -0
  14. sglang/srt/conversation.py +2 -2
  15. sglang/srt/hf_transformers_utils.py +2 -1
  16. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  17. sglang/srt/layers/extend_attention.py +1 -0
  18. sglang/srt/layers/logits_processor.py +114 -54
  19. sglang/srt/layers/radix_attention.py +2 -1
  20. sglang/srt/layers/token_attention.py +1 -0
  21. sglang/srt/managers/detokenizer_manager.py +5 -1
  22. sglang/srt/managers/io_struct.py +12 -0
  23. sglang/srt/managers/router/infer_batch.py +70 -33
  24. sglang/srt/managers/router/manager.py +7 -2
  25. sglang/srt/managers/router/model_rpc.py +116 -73
  26. sglang/srt/managers/router/model_runner.py +111 -167
  27. sglang/srt/managers/router/radix_cache.py +46 -38
  28. sglang/srt/managers/tokenizer_manager.py +56 -11
  29. sglang/srt/memory_pool.py +5 -14
  30. sglang/srt/model_config.py +7 -0
  31. sglang/srt/models/commandr.py +376 -0
  32. sglang/srt/models/dbrx.py +413 -0
  33. sglang/srt/models/dbrx_config.py +281 -0
  34. sglang/srt/models/gemma.py +22 -20
  35. sglang/srt/models/llama2.py +23 -21
  36. sglang/srt/models/llava.py +12 -10
  37. sglang/srt/models/mixtral.py +27 -25
  38. sglang/srt/models/qwen.py +23 -21
  39. sglang/srt/models/qwen2.py +23 -21
  40. sglang/srt/models/stablelm.py +20 -21
  41. sglang/srt/models/yivl.py +6 -5
  42. sglang/srt/openai_api_adapter.py +356 -0
  43. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
  44. sglang/srt/sampling_params.py +2 -0
  45. sglang/srt/server.py +68 -447
  46. sglang/srt/server_args.py +76 -49
  47. sglang/srt/utils.py +88 -32
  48. sglang/srt/weight_utils.py +402 -0
  49. sglang/test/test_programs.py +8 -7
  50. sglang/test/test_utils.py +195 -7
  51. {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/METADATA +12 -14
  52. sglang-0.1.15.dist-info/RECORD +69 -0
  53. sglang-0.1.14.dist-info/RECORD +0 -64
  54. {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/LICENSE +0 -0
  55. {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/WHEEL +0 -0
  56. {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/top_level.txt +0 -0
sglang/srt/models/gemma.py CHANGED
@@ -4,47 +4,49 @@
 from typing import Optional, Tuple
 
 import torch
-from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.radix_attention import RadixAttention
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import LoRAConfig
-from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
-    LinearMethodBase,
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from vllm.model_executor.parallel_utils.parallel_state import (
+from vllm.distributed import (
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.weight_utils import (
+from sglang.srt.weight_utils import (
     default_weight_loader,
     hf_model_weights_iterator,
 )
 
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.router.model_runner import InputMetadata
+
 
 class GemmaMLP(nn.Module):
     def __init__(
         self,
         hidden_size: int,
         intermediate_size: int,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size,
             [intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, linear_method=linear_method
+            intermediate_size, hidden_size, bias=False, quant_config=quant_config,
         )
         self.act_fn = GeluAndMul()
 
@@ -65,7 +67,7 @@ class GemmaAttention(nn.Module):
         layer_id: int = 0,
         max_position_embeddings: int = 8192,
         rope_theta: float = 10000,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -95,13 +97,13 @@ class GemmaAttention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
         self.rotary_emb = get_rope(
@@ -138,7 +140,7 @@ class GemmaDecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int = 0,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -150,12 +152,12 @@ class GemmaDecoderLayer(nn.Module):
             layer_id=layer_id,
             max_position_embeddings=config.max_position_embeddings,
             rope_theta=config.rope_theta,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.mlp = GemmaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -191,7 +193,7 @@ class GemmaModel(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -202,7 +204,7 @@ class GemmaModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                GemmaDecoderLayer(config, i, linear_method)
+                GemmaDecoderLayer(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -263,14 +265,14 @@ class GemmaForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         del lora_config  # Unused.
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.model = GemmaModel(config, linear_method)
+        self.quant_config = quant_config
+        self.model = GemmaModel(config, quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config)
 
     @torch.no_grad()
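The hunks above (and the matching ones in the model files below) all make the same change: every module constructor drops linear_method: Optional[LinearMethodBase] in favor of quant_config: Optional[QuantizationConfig], following the newer vllm quantization API, and the sglang-specific imports move below the third-party ones. A minimal sketch of what a caller adapts to; build_model is an illustrative name, not part of this diff:

# Hedged sketch of the 0.1.14 -> 0.1.15 constructor migration shown above.
# `build_model` is a hypothetical helper used only for illustration.
from typing import Optional

from transformers import PretrainedConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

from sglang.srt.models.gemma import GemmaForCausalLM


def build_model(
    config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig] = None,
) -> GemmaForCausalLM:
    # 0.1.14 passed a LinearMethodBase as the second positional argument;
    # 0.1.15 threads the quantization config through by keyword instead.
    return GemmaForCausalLM(config, quant_config=quant_config)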
sglang/srt/models/llama2.py CHANGED
@@ -1,35 +1,37 @@
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/671af2b1c0b3ed6d856d37c21a561cc429a10701/vllm/model_executor/models/llama.py#L1
 """Inference-only LLaMA model compatible with HuggingFace weights."""
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple
 
 import torch
-from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
 from torch import nn
 from transformers import LlamaConfig
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
-    LinearMethodBase,
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.parallel_utils.parallel_state import (
+from vllm.distributed import (
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.weight_utils import (
+from sglang.srt.weight_utils import (
     default_weight_loader,
     hf_model_weights_iterator,
 )
 
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.router.model_runner import InputMetadata
+
 
 class LlamaMLP(nn.Module):
     def __init__(
@@ -37,17 +39,17 @@ class LlamaMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size,
             [intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, linear_method=linear_method
+            intermediate_size, hidden_size, bias=False, quant_config=quant_config,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -73,7 +75,7 @@ class LlamaAttention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -104,13 +106,13 @@ class LlamaAttention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
         self.rotary_emb = get_rope(
@@ -147,7 +149,7 @@ class LlamaDecoderLayer(nn.Module):
         self,
         config: LlamaConfig,
         layer_id: int = 0,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -162,13 +164,13 @@ class LlamaDecoderLayer(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -204,7 +206,7 @@ class LlamaModel(nn.Module):
     def __init__(
         self,
         config: LlamaConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -216,7 +218,7 @@ class LlamaModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                LlamaDecoderLayer(config, i, linear_method)
+                LlamaDecoderLayer(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -250,12 +252,12 @@ class LlamaForCausalLM(nn.Module):
     def __init__(
         self,
         config: LlamaConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.model = LlamaModel(config, linear_method)
+        self.quant_config = quant_config
+        self.model = LlamaModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
 
sglang/srt/models/llava.py CHANGED
@@ -4,6 +4,16 @@ from typing import List, Optional
 
 import numpy as np
 import torch
+from torch import nn
+from transformers import CLIPVisionModel, LlavaConfig
+from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+from sglang.srt.weight_utils import (
+    default_weight_loader,
+    hf_model_weights_iterator,
+)
+
 from sglang.srt.managers.router.infer_batch import ForwardMode
 from sglang.srt.managers.router.model_runner import InputMetadata
 from sglang.srt.mm_utils import (
@@ -12,21 +22,13 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.models.llama2 import LlamaForCausalLM
-from torch import nn
-from transformers import CLIPVisionModel, LlamaConfig, LlavaConfig
-from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
-from vllm.model_executor.layers.linear import LinearMethodBase
-from vllm.model_executor.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 
 
 class LlavaLlamaForCausalLM(nn.Module):
     def __init__(
         self,
         config: LlavaConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -34,7 +36,7 @@ class LlavaLlamaForCausalLM(nn.Module):
         self.config.vision_config.hidden_size = config.mm_hidden_size
         self.config.text_config.hidden_size = config.hidden_size
         self.multi_modal_projector = LlavaMultiModalProjector(config)
-        self.language_model = LlamaForCausalLM(config, linear_method)
+        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
         if "unpad" in getattr(config, "mm_patch_merge_type", ""):
             self.language_model.model.image_newline = nn.Parameter(
                 torch.empty(config.text_config.hidden_size, dtype=torch.float16)
sglang/srt/models/mixtral.py CHANGED
@@ -1,40 +1,42 @@
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/d0215a58e78572d91dadafe9d832a2db89b09a13/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Mixtral model."""
-from typing import List, Optional, Tuple
+from typing import Optional
 
 import numpy as np
 import torch
 import torch.nn.functional as F
-from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
 from torch import nn
 from transformers import MixtralConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
-    LinearMethodBase,
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.parallel_utils.communication_op import (
+from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.parallel_utils.parallel_state import (
+from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.weight_utils import (
+from sglang.srt.weight_utils import (
     default_weight_loader,
     hf_model_weights_iterator,
 )
 
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.router.model_runner import InputMetadata
+
 
 class MixtralMLP(nn.Module):
     def __init__(
@@ -42,7 +44,7 @@ class MixtralMLP(nn.Module):
         num_experts: int,
         hidden_size: int,
         intermediate_size: int,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.num_experts = num_experts
@@ -50,13 +52,13 @@ class MixtralMLP(nn.Module):
         self.hidden_dim = hidden_size
 
         self.w1 = ReplicatedLinear(
-            self.hidden_dim, self.ffn_dim, bias=False, linear_method=linear_method
+            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
        )
         self.w2 = ReplicatedLinear(
-            self.ffn_dim, self.hidden_dim, bias=False, linear_method=linear_method
+            self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config
        )
         self.w3 = ReplicatedLinear(
-            self.hidden_dim, self.ffn_dim, bias=False, linear_method=linear_method
+            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
        )
 
         # TODO: Use vllm's SiluAndMul
@@ -75,7 +77,7 @@ class MixtralMoE(nn.Module):
     def __init__(
         self,
         config: MixtralConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -102,7 +104,7 @@ class MixtralMoE(nn.Module):
                     self.num_total_experts,
                     config.hidden_size,
                     config.intermediate_size,
-                    linear_method=linear_method,
+                    quant_config=quant_config,
                 )
                 if idx in self.expert_indicies
                 else None
@@ -147,7 +149,7 @@ class MixtralAttention(nn.Module):
         layer_id: int = 0,
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         sliding_window: Optional[int] = None,
     ) -> None:
         super().__init__()
@@ -179,13 +181,13 @@ class MixtralAttention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -221,7 +223,7 @@ class MixtralDecoderLayer(nn.Module):
         self,
         config: MixtralConfig,
         layer_id: int = 0,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -235,9 +237,9 @@ class MixtralDecoderLayer(nn.Module):
             layer_id=layer_id,
             rope_theta=rope_theta,
             sliding_window=config.sliding_window,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
-        self.block_sparse_moe = MixtralMoE(config=config, linear_method=linear_method)
+        self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config)
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
@@ -272,7 +274,7 @@ class MixtralModel(nn.Module):
     def __init__(
         self,
         config: MixtralConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
@@ -285,7 +287,7 @@ class MixtralModel(nn.Module):
         # config.num_hidden_layers=16
         self.layers = nn.ModuleList(
             [
-                MixtralDecoderLayer(config, i, linear_method=linear_method)
+                MixtralDecoderLayer(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -316,12 +318,12 @@ class MixtralForCausalLM(nn.Module):
     def __init__(
         self,
         config: MixtralConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.model = MixtralModel(config, linear_method)
+        self.quant_config = quant_config
+        self.model = MixtralModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
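Beyond the quant_config migration, the MixtralMoE hunk keeps the per-rank expert ownership check (if idx in self.expert_indicies else None): each tensor-parallel rank instantiates only its own slice of experts and stores None placeholders for the rest. A rough standalone sketch of that partitioning idea; the np.array_split detail is an assumption for illustration, since only the membership test is visible in this diff:

# Illustrative sketch of per-rank expert ownership implied by the
# `if idx in self.expert_indicies else None` pattern above; not sglang's
# verbatim implementation.
import numpy as np


def owned_expert_ids(num_total_experts: int, tp_size: int, tp_rank: int) -> list:
    # Split expert indices into tp_size chunks and return the chunk this rank owns.
    return np.array_split(range(num_total_experts), tp_size)[tp_rank].tolist()


# Example: 8 Mixtral experts over 2 ranks -> rank 0 owns [0, 1, 2, 3], rank 1 owns [4, 5, 6, 7].
print(owned_expert_ids(8, 2, 0))
print(owned_expert_ids(8, 2, 1))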
 
sglang/srt/models/qwen.py CHANGED
@@ -1,32 +1,34 @@
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Optional
 
 import torch
-from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
-    LinearMethodBase,
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.parallel_utils.parallel_state import (
+from vllm.distributed import (
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.weight_utils import (
+from sglang.srt.weight_utils import (
     default_weight_loader,
     hf_model_weights_iterator,
 )
 
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.router.model_runner import InputMetadata
+
 
 class QWenMLP(nn.Module):
     def __init__(
@@ -34,7 +36,7 @@ class QWenMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str = "silu",
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -42,14 +44,14 @@ class QWenMLP(nn.Module):
             2 * [intermediate_size],
             bias=False,
             gather_output=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             input_is_parallel=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -74,7 +76,7 @@ class QWenAttention(nn.Module):
         layer_id: int = 0,
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -90,14 +92,14 @@ class QWenAttention(nn.Module):
             self.head_dim,
             self.total_num_heads,
             bias=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.c_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             input_is_parallel=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -130,7 +132,7 @@ class QWenAttention(nn.Module):
 
 
 class QWenBlock(nn.Module):
-    def __init__(self, config: PretrainedConfig, layer_id, linear_method=None):
+    def __init__(self, config: PretrainedConfig, layer_id, quant_config: Optional[QuantizationConfig] = None,):
         super().__init__()
         self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
@@ -143,7 +145,7 @@ class QWenBlock(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             layer_id=layer_id,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
         self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -151,7 +153,7 @@ class QWenBlock(nn.Module):
         self.mlp = QWenMLP(
             config.hidden_size,
             config.intermediate_size // 2,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
     def forward(
@@ -179,7 +181,7 @@ class QWenBlock(nn.Module):
 
 
 class QWenModel(nn.Module):
-    def __init__(self, config: PretrainedConfig, linear_method=None):
+    def __init__(self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,):
         super().__init__()
         self.config = config
         self.vocab_size = config.vocab_size
@@ -191,7 +193,7 @@ class QWenModel(nn.Module):
         )
         self.h = nn.ModuleList(
             [
-                QWenBlock(config, i, linear_method=linear_method)
+                QWenBlock(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -216,10 +218,10 @@ class QWenModel(nn.Module):
 
 
 class QWenLMHeadModel(nn.Module):
-    def __init__(self, config: PretrainedConfig, linear_method=None):
+    def __init__(self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,):
         super().__init__()
         self.config = config
-        self.transformer = QWenModel(config, linear_method=linear_method)
+        self.transformer = QWenModel(config, quant_config=quant_config)
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
@@ -274,4 +276,4 @@ class QWenLMHeadModel(nn.Module):
             weight_loader(param, loaded_weight)
 
 
-EntryClass = QWenLMHeadModel
+EntryClass = QWenLMHeadModel
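All of these model files now import default_weight_loader and hf_model_weights_iterator from the new sglang/srt/weight_utils.py (+402 lines in the file list above) rather than from vllm.model_executor.weight_utils, and the load_weights tail visible in the qwen.py hunk ends with weight_loader(param, loaded_weight). A simplified sketch of that shared loading loop, with the per-model name filtering and fused-weight handling (not shown in this diff) treated as assumptions:

# Hedged sketch of the load_weights pattern shared by these models; the real
# methods also remap stacked/fused parameter names, which is omitted here.
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


def load_weights(self, model_name_or_path, cache_dir=None, load_format="auto", revision=None):
    params_dict = dict(self.named_parameters())
    for name, loaded_weight in hf_model_weights_iterator(
        model_name_or_path, cache_dir, load_format, revision
    ):
        if name not in params_dict:
            # e.g. rotary inverse-frequency buffers that are recomputed at load time
            continue
        param = params_dict[name]
        # Parallel linear layers attach their own weight_loader to the parameter;
        # plain parameters fall back to default_weight_loader.
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, loaded_weight)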