sglang-0.3.0-py3-none-any.whl → sglang-0.3.1.post1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/bench_latency.py +17 -8
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +5 -17
  4. sglang/lang/backend/runtime_endpoint.py +5 -2
  5. sglang/lang/interpreter.py +1 -4
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +33 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +1 -3
  14. sglang/srt/layers/activation.py +12 -0
  15. sglang/srt/layers/attention_backend.py +480 -0
  16. sglang/srt/layers/flashinfer_utils.py +235 -0
  17. sglang/srt/layers/fused_moe/layer.py +27 -7
  18. sglang/srt/layers/layernorm.py +12 -0
  19. sglang/srt/layers/logits_processor.py +64 -77
  20. sglang/srt/layers/radix_attention.py +11 -161
  21. sglang/srt/layers/sampler.py +38 -122
  22. sglang/srt/layers/torchao_utils.py +75 -0
  23. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  24. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  25. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  26. sglang/srt/lora/lora.py +403 -0
  27. sglang/srt/lora/lora_config.py +43 -0
  28. sglang/srt/lora/lora_manager.py +259 -0
  29. sglang/srt/managers/controller_multi.py +1 -5
  30. sglang/srt/managers/controller_single.py +0 -5
  31. sglang/srt/managers/io_struct.py +16 -1
  32. sglang/srt/managers/policy_scheduler.py +122 -5
  33. sglang/srt/managers/schedule_batch.py +105 -71
  34. sglang/srt/managers/tokenizer_manager.py +17 -8
  35. sglang/srt/managers/tp_worker.py +188 -121
  36. sglang/srt/model_executor/cuda_graph_runner.py +69 -133
  37. sglang/srt/model_executor/forward_batch_info.py +35 -312
  38. sglang/srt/model_executor/model_runner.py +123 -154
  39. sglang/srt/models/baichuan.py +416 -0
  40. sglang/srt/models/chatglm.py +1 -5
  41. sglang/srt/models/commandr.py +1 -5
  42. sglang/srt/models/dbrx.py +1 -5
  43. sglang/srt/models/deepseek.py +1 -5
  44. sglang/srt/models/deepseek_v2.py +7 -6
  45. sglang/srt/models/exaone.py +1 -5
  46. sglang/srt/models/gemma.py +1 -5
  47. sglang/srt/models/gemma2.py +1 -5
  48. sglang/srt/models/gpt_bigcode.py +1 -5
  49. sglang/srt/models/grok.py +1 -5
  50. sglang/srt/models/internlm2.py +1 -5
  51. sglang/srt/models/llama.py +51 -5
  52. sglang/srt/models/llama_classification.py +1 -20
  53. sglang/srt/models/llava.py +30 -5
  54. sglang/srt/models/llavavid.py +2 -2
  55. sglang/srt/models/minicpm.py +1 -5
  56. sglang/srt/models/minicpm3.py +669 -0
  57. sglang/srt/models/mixtral.py +6 -5
  58. sglang/srt/models/mixtral_quant.py +1 -5
  59. sglang/srt/models/olmoe.py +415 -0
  60. sglang/srt/models/qwen.py +1 -5
  61. sglang/srt/models/qwen2.py +1 -5
  62. sglang/srt/models/qwen2_moe.py +6 -5
  63. sglang/srt/models/stablelm.py +1 -5
  64. sglang/srt/models/xverse.py +375 -0
  65. sglang/srt/models/xverse_moe.py +445 -0
  66. sglang/srt/openai_api/adapter.py +65 -46
  67. sglang/srt/openai_api/protocol.py +11 -3
  68. sglang/srt/sampling/sampling_batch_info.py +46 -80
  69. sglang/srt/server.py +30 -15
  70. sglang/srt/server_args.py +163 -28
  71. sglang/srt/utils.py +19 -51
  72. sglang/test/few_shot_gsm8k.py +132 -0
  73. sglang/test/runners.py +114 -22
  74. sglang/test/test_programs.py +7 -5
  75. sglang/test/test_utils.py +85 -2
  76. sglang/utils.py +32 -37
  77. sglang/version.py +1 -1
  78. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/METADATA +30 -18
  79. sglang-0.3.1.post1.dist-info/RECORD +130 -0
  80. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/WHEEL +1 -1
  81. sglang-0.3.0.dist-info/RECORD +0 -118
  82. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/LICENSE +0 -0
  83. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/top_level.txt +0 -0
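
To reproduce any of the per-file comparisons listed above locally, the two wheels can be diffed directly; a wheel is a zip archive, so the Python standard library is enough. The following sketch is illustrative only and not part of the published diff: it assumes both wheel files have already been downloaded (for example with pip download sglang==0.3.0 --no-deps and pip download sglang==0.3.1.post1 --no-deps) and uses sglang/srt/models/xverse.py, the new file shown below, as the example path.

import difflib
import zipfile

OLD_WHEEL = "sglang-0.3.0-py3-none-any.whl"
NEW_WHEEL = "sglang-0.3.1.post1-py3-none-any.whl"
PATH = "sglang/srt/models/xverse.py"  # any path from the list above


def read_lines(wheel: str, path: str) -> list:
    # A wheel is a zip archive; a missing member means the file exists in
    # only one of the two versions (added or removed).
    with zipfile.ZipFile(wheel) as zf:
        try:
            return zf.read(path).decode("utf-8").splitlines(keepends=True)
        except KeyError:
            return []


diff = difflib.unified_diff(
    read_lines(OLD_WHEEL, PATH),
    read_lines(NEW_WHEEL, PATH),
    fromfile=f"{OLD_WHEEL}/{PATH}",
    tofile=f"{NEW_WHEEL}/{PATH}",
)
print("".join(diff))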
sglang/srt/models/xverse.py (new file, +375 -0)
@@ -0,0 +1,375 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ # Adapted from
+ # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/xverse.py#L1
+ """Inference-only XVERSE model compatible with HuggingFace weights."""
+
+ from typing import Any, Dict, Iterable, Optional, Tuple
+
+ import torch
+ from torch import nn
+ from transformers import LlamaConfig
+ from vllm.config import CacheConfig
+ from vllm.distributed import get_tensor_model_parallel_world_size
+ from vllm.model_executor.layers.activation import SiluAndMul
+ from vllm.model_executor.layers.layernorm import RMSNorm
+ from vllm.model_executor.layers.linear import (
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
+ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+ from vllm.model_executor.layers.rotary_embedding import get_rope
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.model_executor.model_runner import InputMetadata
+
+
+ class XverseMLP(nn.Module):
+     def __init__(
+         self,
+         hidden_size: int,
+         intermediate_size: int,
+         hidden_act: str,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.gate_up_proj = MergedColumnParallelLinear(
+             hidden_size,
+             [intermediate_size] * 2,
+             bias=False,
+             quant_config=quant_config,
+             prefix=f"{prefix}.gate_up_proj",
+         )
+         self.down_proj = RowParallelLinear(
+             intermediate_size,
+             hidden_size,
+             bias=False,
+             quant_config=quant_config,
+             prefix=f"{prefix}.down_proj",
+         )
+         if hidden_act != "silu":
+             raise ValueError(
+                 f"Unsupported activation: {hidden_act}. "
+                 "Only silu is supported for now."
+             )
+         self.act_fn = SiluAndMul()
+
+     def forward(self, x):
+         gate_up, _ = self.gate_up_proj(x)
+         x = self.act_fn(gate_up)
+         x, _ = self.down_proj(x)
+         return x
+
+
+ class XverseAttention(nn.Module):
+     def __init__(
+         self,
+         config: LlamaConfig,
+         hidden_size: int,
+         num_heads: int,
+         num_kv_heads: int,
+         layer_id: int = 0,
+         rope_theta: float = 10000,
+         rope_scaling: Optional[Dict[str, Any]] = None,
+         rope_is_neox_style: bool = True,
+         max_position_embeddings: int = 8192,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.hidden_size = hidden_size
+         tp_size = get_tensor_model_parallel_world_size()
+         self.total_num_heads = num_heads
+         assert self.total_num_heads % tp_size == 0
+         self.num_heads = self.total_num_heads // tp_size
+         self.total_num_kv_heads = num_kv_heads
+         if self.total_num_kv_heads >= tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+         # MistralConfig has an optional head_dim introduced by Mistral-Nemo
+         self.head_dim = getattr(
+             config, "head_dim", self.hidden_size // self.total_num_heads
+         )
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+         self.rope_theta = rope_theta
+         self.max_position_embeddings = max_position_embeddings
+
+         self.qkv_proj = QKVParallelLinear(
+             hidden_size,
+             self.head_dim,
+             self.total_num_heads,
+             self.total_num_kv_heads,
+             bias=False,
+             quant_config=quant_config,
+             prefix=f"{prefix}.qkv_proj",
+         )
+         self.o_proj = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             hidden_size,
+             bias=False,
+             quant_config=quant_config,
+             prefix=f"{prefix}.o_proj",
+         )
+
+         self.rotary_emb = get_rope(
+             self.head_dim,
+             rotary_dim=self.head_dim,
+             max_position=max_position_embeddings,
+             base=rope_theta,
+             rope_scaling=rope_scaling,
+             is_neox_style=rope_is_neox_style,
+         )
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             layer_id=layer_id,
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+         q, k = self.rotary_emb(positions, q, k)
+         attn_output = self.attn(q, k, v, input_metadata)
+         output, _ = self.o_proj(attn_output)
+         return output
+
+
+ class XverseDecoderLayer(nn.Module):
+     def __init__(
+         self,
+         config: LlamaConfig,
+         layer_id: int = 0,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         rope_theta = getattr(config, "rope_theta", 10000)
+         rope_scaling = getattr(config, "rope_scaling", None)
+         if rope_scaling is not None and getattr(
+             config, "original_max_position_embeddings", None
+         ):
+             rope_scaling["original_max_position_embeddings"] = (
+                 config.original_max_position_embeddings
+             )
+         rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
+         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+         num_kv_heads = getattr(
+             config, "num_key_value_heads", config.num_attention_heads
+         )
+         self.self_attn = XverseAttention(
+             config=config,
+             hidden_size=self.hidden_size,
+             num_heads=config.num_attention_heads,
+             num_kv_heads=num_kv_heads,
+             layer_id=layer_id,
+             rope_theta=rope_theta,
+             rope_scaling=rope_scaling,
+             rope_is_neox_style=rope_is_neox_style,
+             max_position_embeddings=max_position_embeddings,
+             quant_config=quant_config,
+             prefix=f"{prefix}.self_attn",
+         )
+         self.mlp = XverseMLP(
+             hidden_size=self.hidden_size,
+             intermediate_size=config.intermediate_size,
+             hidden_act=config.hidden_act,
+             quant_config=quant_config,
+             prefix=f"{prefix}.mlp",
+         )
+         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+         residual: Optional[torch.Tensor],
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         # Self Attention
+         if residual is None:
+             residual = hidden_states
+             hidden_states = self.input_layernorm(hidden_states)
+         else:
+             hidden_states, residual = self.input_layernorm(hidden_states, residual)
+         hidden_states = self.self_attn(
+             positions=positions,
+             hidden_states=hidden_states,
+             input_metadata=input_metadata,
+         )
+
+         # Fully Connected
+         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+         hidden_states = self.mlp(hidden_states)
+         return hidden_states, residual
+
+
+ class XverseModel(nn.Module):
+     def __init__(
+         self,
+         config: LlamaConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+         self.embed_tokens = VocabParallelEmbedding(
+             config.vocab_size,
+             config.hidden_size,
+         )
+         self.layers = nn.ModuleList(
+             [
+                 XverseDecoderLayer(
+                     config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
+                 )
+                 for i in range(config.num_hidden_layers)
+             ]
+         )
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         if input_embeds is None:
+             hidden_states = self.embed_tokens(input_ids)
+         else:
+             hidden_states = input_embeds
+         residual = None
+         for i in range(len(self.layers)):
+             layer = self.layers[i]
+             hidden_states, residual = layer(
+                 positions,
+                 hidden_states,
+                 input_metadata,
+                 residual,
+             )
+             # print(f"layer[{i}].hidden_states: {hidden_states}")
+         hidden_states, _ = self.norm(hidden_states, residual)
+         return hidden_states
+
+
+ class XverseForCausalLM(nn.Module):
+     def __init__(
+         self,
+         config: LlamaConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
+         efficient_weight_load=False,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.quant_config = quant_config
+         self.model = XverseModel(config, quant_config=quant_config)
+         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+         self.logits_processor = LogitsProcessor(config)
+
+         self.param_dict = dict(self.named_parameters())
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+         return self.logits_processor(
+             input_ids, hidden_states, self.lm_head.weight, input_metadata
+         )
+
+     def load_weights(
+         self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
+     ):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+         params_dict = self.param_dict
+
+         def load_weights_per_param(name, loaded_weight):
+             if "rotary_emb.inv_freq" in name or "projector" in name:
+                 return
+             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                 # Models trained using ColossalAI may include these tensors in
+                 # the checkpoint. Skip them.
+                 return
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 if name.startswith("model.vision_tower") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     return
+                 if name.startswith("model.vision_tower") and name not in params_dict:
+                     return
+                 param = params_dict[name]
+                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                 weight_loader(param, loaded_weight)
+
+         if name is None or loaded_weight is None:
+             for name, loaded_weight in weights:
+                 load_weights_per_param(name, loaded_weight)
+         else:
+             load_weights_per_param(name, loaded_weight)
+
+
+ EntryClass = XverseForCausalLM
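
EntryClass is how sglang's model loader discovers the implementation for a checkpoint's reported architecture, so with this file in place XVERSE checkpoints load like any other supported model. Below is a minimal usage sketch, not taken from the diff, based on the offline Runtime frontend shipped in this release; the checkpoint name xverse/XVERSE-7B-Chat is an assumed example, and chat-template details may need adjusting for a particular model.

import sglang as sgl

# Assumed example checkpoint; substitute any XVERSE model whose architecture
# matches the XverseForCausalLM entry added above.
runtime = sgl.Runtime(model_path="xverse/XVERSE-7B-Chat")
sgl.set_default_backend(runtime)


@sgl.function
def ask(s, question):
    s += sgl.user(question)
    s += sgl.assistant(sgl.gen("answer", max_new_tokens=128))


state = ask.run(question="Briefly introduce the XVERSE model family.")
print(state["answer"])
runtime.shutdown()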