sglang 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sglang/__init__.py +57 -2
  2. sglang/api.py +8 -5
  3. sglang/backend/anthropic.py +18 -4
  4. sglang/backend/openai.py +2 -1
  5. sglang/backend/runtime_endpoint.py +18 -5
  6. sglang/backend/vertexai.py +1 -0
  7. sglang/global_config.py +5 -1
  8. sglang/lang/chat_template.py +83 -2
  9. sglang/lang/interpreter.py +92 -35
  10. sglang/lang/ir.py +12 -9
  11. sglang/lang/tracer.py +6 -4
  12. sglang/launch_server_llavavid.py +31 -0
  13. sglang/srt/constrained/fsm_cache.py +1 -0
  14. sglang/srt/constrained/jump_forward.py +1 -0
  15. sglang/srt/conversation.py +2 -2
  16. sglang/srt/flush_cache.py +16 -0
  17. sglang/srt/hf_transformers_utils.py +10 -2
  18. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  19. sglang/srt/layers/extend_attention.py +1 -0
  20. sglang/srt/layers/logits_processor.py +114 -54
  21. sglang/srt/layers/radix_attention.py +2 -1
  22. sglang/srt/layers/token_attention.py +1 -0
  23. sglang/srt/managers/detokenizer_manager.py +5 -1
  24. sglang/srt/managers/io_struct.py +27 -3
  25. sglang/srt/managers/router/infer_batch.py +97 -48
  26. sglang/srt/managers/router/manager.py +11 -8
  27. sglang/srt/managers/router/model_rpc.py +169 -90
  28. sglang/srt/managers/router/model_runner.py +110 -166
  29. sglang/srt/managers/router/radix_cache.py +89 -51
  30. sglang/srt/managers/router/scheduler.py +17 -28
  31. sglang/srt/managers/tokenizer_manager.py +110 -33
  32. sglang/srt/memory_pool.py +5 -14
  33. sglang/srt/model_config.py +11 -0
  34. sglang/srt/models/commandr.py +372 -0
  35. sglang/srt/models/dbrx.py +412 -0
  36. sglang/srt/models/dbrx_config.py +281 -0
  37. sglang/srt/models/gemma.py +24 -25
  38. sglang/srt/models/llama2.py +25 -26
  39. sglang/srt/models/llava.py +8 -10
  40. sglang/srt/models/llavavid.py +307 -0
  41. sglang/srt/models/mixtral.py +29 -33
  42. sglang/srt/models/qwen.py +34 -25
  43. sglang/srt/models/qwen2.py +25 -26
  44. sglang/srt/models/stablelm.py +26 -26
  45. sglang/srt/models/yivl.py +3 -5
  46. sglang/srt/openai_api_adapter.py +356 -0
  47. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
  48. sglang/srt/sampling_params.py +2 -0
  49. sglang/srt/server.py +91 -456
  50. sglang/srt/server_args.py +79 -49
  51. sglang/srt/utils.py +212 -47
  52. sglang/srt/weight_utils.py +417 -0
  53. sglang/test/test_programs.py +8 -7
  54. sglang/test/test_utils.py +195 -7
  55. sglang/utils.py +77 -26
  56. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/METADATA +20 -18
  57. sglang-0.1.16.dist-info/RECORD +72 -0
  58. sglang-0.1.14.dist-info/RECORD +0 -64
  59. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/LICENSE +0 -0
  60. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/WHEEL +0 -0
  61. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/top_level.txt +0 -0
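Most of the per-model churn in the diffs below follows one pattern: the vllm constructor argument `linear_method: Optional[LinearMethodBase]` is replaced by `quant_config: Optional[QuantizationConfig]`, tensor-parallel helpers now come from `vllm.distributed`, and the weight-loading helpers move to the new in-tree `sglang.srt.weight_utils` module. A minimal sketch of the constructor change, assuming a hypothetical `ToyMLP` module that stands in for the real classes (MixtralMLP, QWenMLP, Qwen2MLP, StablelmMLP):

# Hedged sketch of the recurring signature migration; ToyMLP is illustrative only.
from typing import Optional

from torch import nn
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig


class ToyMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        # previously: linear_method: Optional[LinearMethodBase] = None
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        # The quantization config is threaded into each parallel linear layer
        # instead of passing a LinearMethodBase instance.
        self.w1 = ReplicatedLinear(
            hidden_size, intermediate_size, bias=False, quant_config=quant_config
        )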
sglang/srt/models/mixtral.py CHANGED
@@ -1,39 +1,35 @@
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/d0215a58e78572d91dadafe9d832a2db89b09a13/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Mixtral model."""
-from typing import List, Optional, Tuple
+from typing import Optional

 import numpy as np
 import torch
 import torch.nn.functional as F
-from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
 from torch import nn
 from transformers import MixtralConfig
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
-    LinearMethodBase,
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.parallel_utils.communication_op import (
-    tensor_model_parallel_all_reduce,
-)
-from vllm.model_executor.parallel_utils.parallel_state import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
-from vllm.model_executor.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 class MixtralMLP(nn.Module):
@@ -42,7 +38,7 @@ class MixtralMLP(nn.Module):
         num_experts: int,
         hidden_size: int,
         intermediate_size: int,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.num_experts = num_experts
@@ -50,13 +46,13 @@ class MixtralMLP(nn.Module):
         self.hidden_dim = hidden_size

         self.w1 = ReplicatedLinear(
-            self.hidden_dim, self.ffn_dim, bias=False, linear_method=linear_method
+            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
         )
         self.w2 = ReplicatedLinear(
-            self.ffn_dim, self.hidden_dim, bias=False, linear_method=linear_method
+            self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config
         )
         self.w3 = ReplicatedLinear(
-            self.hidden_dim, self.ffn_dim, bias=False, linear_method=linear_method
+            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
         )

         # TODO: Use vllm's SiluAndMul
@@ -75,7 +71,7 @@ class MixtralMoE(nn.Module):
     def __init__(
         self,
         config: MixtralConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -102,7 +98,7 @@ class MixtralMoE(nn.Module):
                     self.num_total_experts,
                     config.hidden_size,
                     config.intermediate_size,
-                    linear_method=linear_method,
+                    quant_config=quant_config,
                 )
                 if idx in self.expert_indicies
                 else None
@@ -147,7 +143,7 @@ class MixtralAttention(nn.Module):
         layer_id: int = 0,
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         sliding_window: Optional[int] = None,
     ) -> None:
         super().__init__()
@@ -179,13 +175,13 @@ class MixtralAttention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -221,7 +217,7 @@ class MixtralDecoderLayer(nn.Module):
         self,
         config: MixtralConfig,
         layer_id: int = 0,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -235,9 +231,9 @@ class MixtralDecoderLayer(nn.Module):
             layer_id=layer_id,
             rope_theta=rope_theta,
             sliding_window=config.sliding_window,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
-        self.block_sparse_moe = MixtralMoE(config=config, linear_method=linear_method)
+        self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config)
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
@@ -272,7 +268,7 @@ class MixtralModel(nn.Module):
     def __init__(
         self,
         config: MixtralConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
@@ -285,7 +281,7 @@ class MixtralModel(nn.Module):
         # config.num_hidden_layers=16
         self.layers = nn.ModuleList(
             [
-                MixtralDecoderLayer(config, i, linear_method=linear_method)
+                MixtralDecoderLayer(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -316,12 +312,12 @@ class MixtralForCausalLM(nn.Module):
     def __init__(
         self,
         config: MixtralConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.model = MixtralModel(config, linear_method)
+        self.quant_config = quant_config
+        self.model = MixtralModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)

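The other import change repeated across these model files is that `default_weight_loader` and `hf_model_weights_iterator` now come from the vendored `sglang.srt.weight_utils` module (added in this release, file 52 above) instead of `vllm.model_executor.weight_utils`. The `load_weights` bodies are not shown in these hunks, so the sketch below is only an assumption about how the relocated helpers are typically consumed, not code taken from this diff:

# Assumed usage pattern for the relocated helpers; not verbatim from this diff.
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


def load_weights(self, model_name_or_path, cache_dir=None, load_format="auto", revision=None):
    params_dict = dict(self.named_parameters())
    for name, loaded_weight in hf_model_weights_iterator(
        model_name_or_path, cache_dir, load_format, revision
    ):
        if name in params_dict:
            param = params_dict[name]
            # Fall back to a plain copy when the parameter defines no custom loader.
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)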
sglang/srt/models/qwen.py CHANGED
@@ -1,31 +1,27 @@
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Optional

 import torch
-from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
 from torch import nn
 from transformers import PretrainedConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
-    LinearMethodBase,
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.parallel_utils.parallel_state import (
-    get_tensor_model_parallel_world_size,
-)
-from vllm.model_executor.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 class QWenMLP(nn.Module):
@@ -34,7 +30,7 @@ class QWenMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str = "silu",
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -42,14 +38,14 @@ class QWenMLP(nn.Module):
             2 * [intermediate_size],
             bias=False,
             gather_output=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             input_is_parallel=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -74,7 +70,7 @@ class QWenAttention(nn.Module):
         layer_id: int = 0,
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -90,14 +86,14 @@ class QWenAttention(nn.Module):
             self.head_dim,
             self.total_num_heads,
             bias=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.c_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             input_is_parallel=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -130,7 +126,12 @@ class QWenAttention(nn.Module):


 class QWenBlock(nn.Module):
-    def __init__(self, config: PretrainedConfig, layer_id, linear_method=None):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
         super().__init__()
         self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

@@ -143,7 +144,7 @@ class QWenBlock(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             layer_id=layer_id,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )

         self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -151,7 +152,7 @@ class QWenBlock(nn.Module):
         self.mlp = QWenMLP(
             config.hidden_size,
             config.intermediate_size // 2,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )

     def forward(
@@ -179,7 +180,11 @@ class QWenBlock(nn.Module):


 class QWenModel(nn.Module):
-    def __init__(self, config: PretrainedConfig, linear_method=None):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
         super().__init__()
         self.config = config
         self.vocab_size = config.vocab_size
@@ -191,7 +196,7 @@ class QWenModel(nn.Module):
         )
         self.h = nn.ModuleList(
             [
-                QWenBlock(config, i, linear_method=linear_method)
+                QWenBlock(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -216,10 +221,14 @@ class QWenModel(nn.Module):


 class QWenLMHeadModel(nn.Module):
-    def __init__(self, config: PretrainedConfig, linear_method=None):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
         super().__init__()
         self.config = config
-        self.transformer = QWenModel(config, linear_method=linear_method)
+        self.transformer = QWenModel(config, quant_config=quant_config)
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
sglang/srt/models/qwen2.py CHANGED
@@ -1,33 +1,29 @@
 # Adapted from llama2.py
 # Modify details for the adaptation of Qwen2 model.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple

 import torch
-from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
 from torch import nn
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
-    LinearMethodBase,
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.parallel_utils.parallel_state import (
-    get_tensor_model_parallel_world_size,
-)
-from vllm.model_executor.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator

 Qwen2Config = None

@@ -38,17 +34,20 @@ class Qwen2MLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size,
             [intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, linear_method=linear_method
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -74,7 +73,7 @@ class Qwen2Attention(nn.Module):
         rope_theta: float = 1000000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 32768,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -105,13 +104,13 @@ class Qwen2Attention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )

         self.rotary_emb = get_rope(
@@ -148,7 +147,7 @@ class Qwen2DecoderLayer(nn.Module):
         self,
         config: Qwen2Config,
         layer_id: int = 0,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -163,13 +162,13 @@ class Qwen2DecoderLayer(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.mlp = Qwen2MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -205,7 +204,7 @@ class Qwen2Model(nn.Module):
     def __init__(
         self,
         config: Qwen2Config,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -217,7 +216,7 @@ class Qwen2Model(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                Qwen2DecoderLayer(config, i, linear_method)
+                Qwen2DecoderLayer(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -251,12 +250,12 @@ class Qwen2ForCausalLM(nn.Module):
     def __init__(
         self,
         config: Qwen2Config,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.model = Qwen2Model(config, linear_method)
+        self.quant_config = quant_config
+        self.model = Qwen2Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)

sglang/srt/models/stablelm.py CHANGED
@@ -7,34 +7,31 @@ from typing import Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-
-from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.linear import (
-    LinearMethodBase,
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding,
     ParallelLMHead,
+    VocabParallelEmbedding,
 )
-from vllm.model_executor.parallel_utils.parallel_state import (
-    get_tensor_model_parallel_world_size,
-)
-from vllm.model_executor.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 class StablelmMLP(nn.Module):
     def __init__(
-        self, config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -44,10 +41,13 @@ class StablelmMLP(nn.Module):
             config.hidden_size,
             [config.intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            config.intermediate_size, config.hidden_size, bias=False
+            config.intermediate_size,
+            config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
         )
         self.act_fn = SiluAndMul()

@@ -63,7 +63,7 @@ class StablelmAttention(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int = 0,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -105,13 +105,11 @@ class StablelmAttention(nn.Module):
             self.total_num_heads,
             self.total_num_key_value_heads,
             self.qkv_bias,
-            linear_method=linear_method,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=False,
-            linear_method=linear_method,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -146,11 +144,11 @@ class StablelmDecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int = 0,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.self_attn = StablelmAttention(config, layer_id=layer_id)
-        self.mlp = StablelmMLP(config, linear_method)
+        self.mlp = StablelmMLP(config, quant_config=quant_config)
         norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05))
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
@@ -182,7 +180,9 @@ class StablelmDecoderLayer(nn.Module):

 class StableLMEpochModel(nn.Module):
     def __init__(
-        self, config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.embed_tokens = VocabParallelEmbedding(
@@ -191,7 +191,7 @@ class StableLMEpochModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                StablelmDecoderLayer(config, i, linear_method)
+                StablelmDecoderLayer(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -224,12 +224,12 @@ class StableLmForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.model = StableLMEpochModel(config, linear_method)
+        self.quant_config = quant_config
+        self.model = StableLMEpochModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)

sglang/srt/models/yivl.py CHANGED
@@ -5,16 +5,14 @@ from typing import List, Optional

 import torch
 import torch.nn as nn
+from transformers import CLIPVisionModel, LlavaConfig
+
 from sglang.srt.models.llava import (
     LlavaLlamaForCausalLM,
     clip_vision_embed_forward,
     monkey_path_clip_vision_embed_forward,
 )
-from transformers import CLIPVisionModel, LlavaConfig
-from vllm.model_executor.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 class YiVLForCausalLM(LlavaLlamaForCausalLM):