sglang 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (40)
  1. sglang/__init__.py +1 -1
  2. sglang/api.py +14 -0
  3. sglang/backend/anthropic.py +18 -12
  4. sglang/backend/base_backend.py +6 -0
  5. sglang/backend/openai.py +41 -12
  6. sglang/backend/runtime_endpoint.py +57 -6
  7. sglang/lang/chat_template.py +47 -26
  8. sglang/lang/interpreter.py +15 -2
  9. sglang/lang/ir.py +1 -1
  10. sglang/srt/constrained/__init__.py +23 -1
  11. sglang/srt/constrained/fsm_cache.py +14 -3
  12. sglang/srt/layers/context_flashattention_nopad.py +1 -1
  13. sglang/srt/layers/extend_attention.py +7 -6
  14. sglang/srt/layers/radix_attention.py +2 -10
  15. sglang/srt/layers/token_attention.py +12 -4
  16. sglang/srt/managers/io_struct.py +3 -1
  17. sglang/srt/managers/router/infer_batch.py +6 -2
  18. sglang/srt/managers/router/model_rpc.py +45 -32
  19. sglang/srt/managers/router/model_runner.py +40 -25
  20. sglang/srt/managers/tokenizer_manager.py +2 -0
  21. sglang/srt/model_config.py +12 -5
  22. sglang/srt/models/gemma.py +340 -0
  23. sglang/srt/models/llama2.py +5 -5
  24. sglang/srt/models/llava.py +2 -4
  25. sglang/srt/models/mixtral.py +5 -5
  26. sglang/srt/models/qwen.py +4 -4
  27. sglang/srt/models/qwen2.py +5 -5
  28. sglang/srt/models/stablelm.py +293 -0
  29. sglang/srt/server.py +111 -47
  30. sglang/srt/server_args.py +44 -9
  31. sglang/srt/utils.py +1 -0
  32. sglang/test/test_utils.py +1 -1
  33. sglang/utils.py +15 -12
  34. {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/METADATA +16 -6
  35. sglang-0.1.14.dist-info/RECORD +64 -0
  36. {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/WHEEL +1 -1
  37. sglang/srt/models/gpt_neox.py +0 -274
  38. sglang-0.1.12.dist-info/RECORD +0 -63
  39. {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/LICENSE +0 -0
  40. {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/top_level.txt +0 -0
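The headline additions are two new model implementations (sglang/srt/models/gemma.py and sglang/srt/models/stablelm.py) alongside the server, router, and backend updates. As a rough, hedged sketch of how the new Gemma support might be exercised end to end with the 0.1.x frontend API (the launch command, model path, and port below are illustrative assumptions, not part of this diff):

# Illustrative only: launch command, model path, and port are assumptions.
#   python -m sglang.launch_server --model-path google/gemma-2b-it --port 30000
import sglang as sgl

sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

@sgl.function
def hello(s, topic):
    s += "Write one sentence about " + topic + ". "
    s += sgl.gen("answer", max_tokens=32)

state = hello.run(topic="the Gemma architecture")
print(state["answer"])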
sglang/srt/models/gemma.py ADDED
@@ -0,0 +1,340 @@
+# Adapted from:
+# https://github.com/vllm-project/vllm/blob/d65fac2738f0287a41955b45df76a2d5a919bff6/vllm/model_executor/models/gemma.py
+"""Inference-only Gemma model compatible with HuggingFace weights."""
+from typing import Optional, Tuple
+
+import torch
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from torch import nn
+from transformers import PretrainedConfig
+from vllm.config import LoRAConfig
+from vllm.model_executor.input_metadata import InputMetadata
+from vllm.model_executor.layers.activation import GeluAndMul
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+    LinearMethodBase,
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.parallel_utils.parallel_state import (
+    get_tensor_model_parallel_world_size,
+)
+from vllm.model_executor.weight_utils import (
+    default_weight_loader,
+    hf_model_weights_iterator,
+)
+
+
+class GemmaMLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            linear_method=linear_method,
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size, hidden_size, bias=False, linear_method=linear_method
+        )
+        self.act_fn = GeluAndMul()
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class GemmaAttention(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        head_dim: int,
+        layer_id: int = 0,
+        max_position_embeddings: int = 8192,
+        rope_theta: float = 10000,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = head_dim
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            linear_method=linear_method,
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            linear_method=linear_method,
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=self.rope_theta,
+            is_neox_style=True,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, input_metadata)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class GemmaDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int = 0,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = GemmaAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            head_dim=config.head_dim,
+            layer_id=layer_id,
+            max_position_embeddings=config.max_position_embeddings,
+            rope_theta=config.rope_theta,
+            linear_method=linear_method,
+        )
+        self.mlp = GemmaMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            linear_method=linear_method,
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            input_metadata=input_metadata,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class GemmaModel(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.layers = nn.ModuleList(
+            [
+                GemmaDecoderLayer(config, i, linear_method)
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+
+        # Normalize the embedding by sqrt(hidden_size)
+        hidden_states *= self.config.hidden_size**0.5
+
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                input_metadata,
+                residual,
+            )
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class GemmaForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+    ]
+    # Gemma does not apply LoRA to the embedding layer.
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ) -> None:
+        del lora_config  # Unused.
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.model = GemmaModel(config, linear_method)
+        self.logits_processor = LogitsProcessor(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        return self.logits_processor(
+            input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
+        )
+
+    def load_weights(
+        self,
+        model_name_or_path: str,
+        cache_dir: Optional[str] = None,
+        load_format: str = "auto",
+        revision: Optional[str] = None,
+    ):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params = set()
+        for name, loaded_weight in hf_model_weights_iterator(
+            model_name_or_path, cache_dir, load_format, revision
+        ):
+            for param_name, shard_name, shard_id in stacked_params_mapping:
+                if shard_name not in name:
+                    continue
+                name = name.replace(shard_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # GemmaRMSNorm is different from Llama's in that it multiplies
+                # (1 + weight) to the output, instead of just weight.
+                if "norm.weight" in name:
+                    loaded_weight += 1.0
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        unloaded_params = params_dict.keys() - loaded_params
+        if unloaded_params:
+            raise RuntimeError(
+                "Some weights are not initialized from checkpoints: "
+                f"{unloaded_params}"
+            )
+
+
+EntryClass = GemmaForCausalLM
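Two details in the file above are easy to miss: GemmaModel scales the embeddings by sqrt(hidden_size) before the first layer, and load_weights adds 1.0 to every norm weight because Gemma's reference RMSNorm scales by (1 + weight) while the stock RMSNorm used here scales by weight alone. A minimal standalone sketch of that second equivalence (plain tensors, illustrative eps and shapes, no vllm dependency):

import torch

def _rms(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Shared normalization step: x / sqrt(mean(x^2) + eps)
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

x = torch.randn(2, 8)
w = 0.1 * torch.randn(8)                 # weight value as stored in the checkpoint

gemma_reference = _rms(x) * (1.0 + w)    # Gemma-style RMSNorm: scale by (1 + weight)
stock_with_shift = _rms(x) * (w + 1.0)   # stock RMSNorm after load_weights adds 1.0

assert torch.allclose(gemma_reference, stock_with_shift)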
sglang/srt/models/llama2.py CHANGED
@@ -227,12 +227,12 @@ class LlamaModel(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        skip_embed: bool = False,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        if not skip_embed:
+        if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:
-            hidden_states = input_ids
+            hidden_states = input_embeds
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
@@ -264,9 +264,9 @@ class LlamaForCausalLM(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        skip_embed: bool = False,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, skip_embed)
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
sglang/srt/models/llava.py CHANGED
@@ -230,12 +230,10 @@ class LlavaLlamaForCausalLM(nn.Module):
                 pt += 1
 
             return self.language_model(
-                input_embeds, positions, input_metadata, skip_embed=True
+                input_ids, positions, input_metadata, input_embeds=input_embeds
             )
         elif input_metadata.forward_mode == ForwardMode.DECODE:
-            return self.language_model(
-                input_ids, positions, input_metadata, skip_embed=False
-            )
+            return self.language_model(input_ids, positions, input_metadata)
 
     def load_weights(
         self,
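The hunks in this group, including the mixtral.py and qwen2.py hunks below, all make the same change: the boolean skip_embed flag is replaced by an optional input_embeds tensor, so a caller that already has embeddings, such as the LLaVA wrapper above, passes them explicitly instead of smuggling them through the input_ids argument. A hedged sketch of the resulting convention (the wrapper function and variable names are illustrative, not from the diff):

import torch

def call_language_model(model, input_ids, positions, input_metadata,
                        vision_embeds: torch.Tensor = None):
    # Illustrative wrapper: mirrors the new keyword-based convention.
    if vision_embeds is not None:
        # Extend/prefill with precomputed embeddings: token ids are still passed,
        # but the embedding lookup inside the model is skipped.
        return model(input_ids, positions, input_metadata, input_embeds=vision_embeds)
    # Decode (or text-only prefill): the model embeds input_ids itself.
    return model(input_ids, positions, input_metadata)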
sglang/srt/models/mixtral.py CHANGED
@@ -296,12 +296,12 @@ class MixtralModel(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        skip_embed: bool = False,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        if not skip_embed:
+        if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:
-            hidden_states = input_ids
+            hidden_states = input_embeds
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
@@ -330,9 +330,9 @@ class MixtralForCausalLM(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        skip_embed: bool = False,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, skip_embed)
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
sglang/srt/models/qwen.py CHANGED
@@ -5,6 +5,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
 from torch import nn
+from transformers import PretrainedConfig
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
@@ -25,7 +26,6 @@ from vllm.model_executor.weight_utils import (
     default_weight_loader,
     hf_model_weights_iterator,
 )
-from vllm.transformers_utils.configs.qwen import QWenConfig
 
 
 class QWenMLP(nn.Module):
@@ -130,7 +130,7 @@ class QWenAttention(nn.Module):
 
 
 class QWenBlock(nn.Module):
-    def __init__(self, config: QWenConfig, layer_id, linear_method=None):
+    def __init__(self, config: PretrainedConfig, layer_id, linear_method=None):
         super().__init__()
         self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
@@ -179,7 +179,7 @@ class QWenBlock(nn.Module):
 
 
 class QWenModel(nn.Module):
-    def __init__(self, config: QWenConfig, linear_method=None):
+    def __init__(self, config: PretrainedConfig, linear_method=None):
         super().__init__()
         self.config = config
         self.vocab_size = config.vocab_size
@@ -216,7 +216,7 @@ class QWenModel(nn.Module):
 
 
 class QWenLMHeadModel(nn.Module):
-    def __init__(self, config: QWenConfig, linear_method=None):
+    def __init__(self, config: PretrainedConfig, linear_method=None):
         super().__init__()
         self.config = config
         self.transformer = QWenModel(config, linear_method=linear_method)
sglang/srt/models/qwen2.py CHANGED
@@ -228,12 +228,12 @@ class Qwen2Model(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        skip_embed: bool = False,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        if not skip_embed:
+        if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:
-            hidden_states = input_ids
+            hidden_states = input_embeds
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
@@ -265,9 +265,9 @@ class Qwen2ForCausalLM(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        skip_embed: bool = False,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, skip_embed)
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )