sglang 0.1.18__py3-none-any.whl → 0.1.19__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
Files changed (38)
  1. sglang/__init__.py +1 -1
  2. sglang/api.py +26 -0
  3. sglang/backend/runtime_endpoint.py +18 -14
  4. sglang/bench_latency.py +34 -16
  5. sglang/global_config.py +1 -0
  6. sglang/lang/chat_template.py +41 -6
  7. sglang/lang/interpreter.py +5 -1
  8. sglang/lang/ir.py +61 -25
  9. sglang/srt/constrained/__init__.py +3 -2
  10. sglang/srt/hf_transformers_utils.py +7 -3
  11. sglang/srt/layers/extend_attention.py +2 -1
  12. sglang/srt/layers/fused_moe.py +181 -167
  13. sglang/srt/layers/logits_processor.py +55 -19
  14. sglang/srt/layers/radix_attention.py +24 -27
  15. sglang/srt/layers/token_attention.py +4 -1
  16. sglang/srt/managers/controller/infer_batch.py +2 -2
  17. sglang/srt/managers/controller/manager_single.py +1 -1
  18. sglang/srt/managers/controller/model_runner.py +27 -15
  19. sglang/srt/managers/controller/tp_worker.py +31 -14
  20. sglang/srt/managers/detokenizer_manager.py +4 -2
  21. sglang/srt/managers/io_struct.py +1 -1
  22. sglang/srt/managers/tokenizer_manager.py +14 -13
  23. sglang/srt/model_config.py +6 -0
  24. sglang/srt/models/gemma2.py +436 -0
  25. sglang/srt/models/llama2.py +3 -3
  26. sglang/srt/models/llama_classification.py +10 -7
  27. sglang/srt/models/minicpm.py +373 -0
  28. sglang/srt/models/qwen2_moe.py +454 -0
  29. sglang/srt/openai_api_adapter.py +2 -2
  30. sglang/srt/openai_protocol.py +1 -1
  31. sglang/srt/server.py +17 -8
  32. sglang/srt/server_args.py +14 -16
  33. sglang/srt/utils.py +68 -35
  34. {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/METADATA +19 -13
  35. {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/RECORD +38 -35
  36. {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/LICENSE +0 -0
  37. {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/WHEEL +0 -0
  38. {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/top_level.txt +0 -0
sglang/srt/models/gemma2.py (new file)
@@ -0,0 +1,436 @@
+ # Adapted from:
+ # https://github.com/vllm-project/vllm/blob/56b325e977435af744f8b3dca7af0ca209663558/vllm/model_executor/models/gemma2.py
+ from typing import Iterable, Optional, Set, Tuple, Union
+
+ import torch
+ from torch import nn
+ from transformers import PretrainedConfig
+ from vllm.config import CacheConfig, LoRAConfig
+ from vllm.distributed import get_tensor_model_parallel_world_size
+
+ # FIXME: temporary solution, remove after next vllm release
+ from vllm.model_executor.custom_op import CustomOp
+ from vllm.model_executor.layers.activation import GeluAndMul
+
+ # from vllm.model_executor.layers.layernorm import GemmaRMSNorm
+ from vllm.model_executor.layers.linear import (
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
+ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+
+ # from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
+ from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
+
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.managers.controller.model_runner import InputMetadata
+
+
+ class GemmaRMSNorm(CustomOp):
+     """RMS normalization for Gemma.
+
+     Two differences from the above RMSNorm:
+         1. x * (1 + w) instead of x * w.
+         2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
+     """
+
+     def __init__(
+         self,
+         hidden_size: int,
+         eps: float = 1e-6,
+     ) -> None:
+         super().__init__()
+         self.weight = nn.Parameter(torch.zeros(hidden_size))
+         self.variance_epsilon = eps
+
+     def forward_native(
+         self,
+         x: torch.Tensor,
+         residual: Optional[torch.Tensor] = None,
+     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+         """PyTorch-native implementation equivalent to forward()."""
+         orig_dtype = x.dtype
+         if residual is not None:
+             x = x + residual
+             residual = x
+
+         x = x.float()
+         variance = x.pow(2).mean(dim=-1, keepdim=True)
+         x = x * torch.rsqrt(variance + self.variance_epsilon)
+         # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
+         # See https://github.com/huggingface/transformers/pull/29402
+         x = x * (1.0 + self.weight.float())
+         x = x.to(orig_dtype)
+         return x if residual is None else (x, residual)
+
+     def forward_cuda(
+         self,
+         x: torch.Tensor,
+         residual: Optional[torch.Tensor] = None,
+     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+         # from vLLM: TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm.
+         return self.forward_native(x, residual)
+
+
+ # FIXME: temporary solution, remove after next vllm release
+ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+
+
+ class GemmaRotaryEmbedding(RotaryEmbedding):
+     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+         # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
+         inv_freq = 1.0 / (
+             base
+             ** (
+                 torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float()
+                 / self.rotary_dim
+             )
+         )
+         return inv_freq
+
+
+ class Gemma2MLP(nn.Module):
+     def __init__(
+         self,
+         hidden_size: int,
+         intermediate_size: int,
+         hidden_act: str,
+         hidden_activation: str,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.gate_up_proj = MergedColumnParallelLinear(
+             hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+         )
+         self.down_proj = RowParallelLinear(
+             intermediate_size, hidden_size, bias=False, quant_config=quant_config
+         )
+         if not (hidden_act == hidden_activation == "gelu_pytorch_tanh"):
+             raise ValueError(
+                 "Gemma2 uses `gelu_pytorch_tanh` as the hidden activation "
+                 "function. Please set `hidden_act` and `hidden_activation` to "
+                 "`gelu_pytorch_tanh`."
+             )
+         self.act_fn = GeluAndMul(approximate="tanh")
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         gate_up, _ = self.gate_up_proj(x)
+         x = self.act_fn(gate_up)
+         x, _ = self.down_proj(x)
+         return x
+
+
+ class Gemma2Attention(nn.Module):
+     def __init__(
+         self,
+         layer_idx: int,
+         config: PretrainedConfig,
+         hidden_size: int,
+         num_heads: int,
+         num_kv_heads: int,
+         head_dim: int,
+         max_position_embeddings: int,
+         rope_theta: float,
+         cache_config: Optional[CacheConfig] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.layer_idx = layer_idx
+         self.config = config
+         self.hidden_size = hidden_size
+         tp_size = get_tensor_model_parallel_world_size()
+         self.total_num_heads = num_heads
+         assert self.total_num_heads % tp_size == 0
+         self.num_heads = self.total_num_heads // tp_size
+         self.total_num_kv_heads = num_kv_heads
+         if self.total_num_kv_heads >= tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+         self.head_dim = head_dim
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = config.query_pre_attn_scalar**-0.5
+         self.rope_theta = rope_theta
+
+         self.qkv_proj = QKVParallelLinear(
+             hidden_size,
+             self.head_dim,
+             self.total_num_heads,
+             self.total_num_kv_heads,
+             bias=config.attention_bias,
+             quant_config=quant_config,
+         )
+         self.o_proj = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             hidden_size,
+             bias=config.attention_bias,
+             quant_config=quant_config,
+         )
+         # from vLLM: TODO(woosuk): Use the `get_rope` interface.
+         self.rotary_emb = GemmaRotaryEmbedding(
+             self.head_dim,
+             self.head_dim,
+             max_position_embeddings,
+             base=self.rope_theta,
+             is_neox_style=True,
+             dtype=torch.get_default_dtype(),
+         )
+
+         # from vLLM: FIXME(woosuk): While Gemma 2 uses sliding window attention for every
+         # odd layer, vLLM currently ignores it and uses global attention for
+         # all layers.
+         use_sliding_window = layer_idx % 2 == 1 and config.sliding_window is not None
+         del use_sliding_window  # Unused.
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             layer_id=layer_idx,
+             logit_cap=self.config.attn_logit_softcapping,
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+         q, k = self.rotary_emb(positions, q, k)
+         attn_output = self.attn(q, k, v, input_metadata)
+         output, _ = self.o_proj(attn_output)
+         return output
+
+
+ class Gemma2DecoderLayer(nn.Module):
+     def __init__(
+         self,
+         layer_idx: int,
+         config: PretrainedConfig,
+         cache_config: Optional[CacheConfig] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.self_attn = Gemma2Attention(
+             layer_idx=layer_idx,
+             config=config,
+             hidden_size=self.hidden_size,
+             num_heads=config.num_attention_heads,
+             num_kv_heads=config.num_key_value_heads,
+             head_dim=config.head_dim,
+             max_position_embeddings=config.max_position_embeddings,
+             rope_theta=config.rope_theta,
+             cache_config=cache_config,
+             quant_config=quant_config,
+         )
+         self.hidden_size = config.hidden_size
+         self.mlp = Gemma2MLP(
+             hidden_size=self.hidden_size,
+             intermediate_size=config.intermediate_size,
+             hidden_act=config.hidden_act,
+             hidden_activation=config.hidden_activation,
+             quant_config=quant_config,
+         )
+         self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = GemmaRMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+         self.pre_feedforward_layernorm = GemmaRMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+         self.post_feedforward_layernorm = GemmaRMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+         residual: Optional[torch.Tensor],
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         if residual is None:
+             residual = hidden_states
+             hidden_states = self.input_layernorm(hidden_states)
+         else:
+             hidden_states, residual = self.input_layernorm(hidden_states, residual)
+         hidden_states = self.self_attn(
+             positions=positions,
+             hidden_states=hidden_states,
+             input_metadata=input_metadata,
+         )
+         hidden_states = self.post_attention_layernorm(hidden_states)
+
+         hidden_states, residual = self.pre_feedforward_layernorm(
+             hidden_states, residual
+         )
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = self.post_feedforward_layernorm(hidden_states)
+         return hidden_states, residual
+
+
+ class Gemma2Model(nn.Module):
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         cache_config: Optional[CacheConfig] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+
+         self.embed_tokens = VocabParallelEmbedding(
+             config.vocab_size,
+             config.hidden_size,
+         )
+         self.layers = nn.ModuleList(
+             [
+                 Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config)
+                 for layer_idx in range(config.num_hidden_layers)
+             ]
+         )
+         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+         # Normalize the embedding by sqrt(hidden_size)
+         # The normalizer's data type should be downcasted to the model's
+         # data type such as bfloat16, not float32.
+         # See https://github.com/huggingface/transformers/pull/29402
+         normalizer = self.config.hidden_size**0.5
+         self.register_buffer("normalizer", torch.tensor(normalizer))
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         if input_embeds is None:
+             hidden_states = self.embed_tokens(input_ids)
+         else:
+             hidden_states = input_embeds
+         normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=torch.float16)
+         hidden_states *= normalizer
+
+         residual = None
+         for i in range(len(self.layers)):
+             layer = self.layers[i]
+             hidden_states, residual = layer(
+                 positions,
+                 hidden_states,
+                 input_metadata,
+                 residual,
+             )
+         hidden_states, _ = self.norm(hidden_states, residual)
+         return hidden_states
+
+
+ class Gemma2ForCausalLM(nn.Module):
+     packed_modules_mapping = {
+         "qkv_proj": [
+             "q_proj",
+             "k_proj",
+             "v_proj",
+         ],
+         "gate_up_proj": [
+             "gate_proj",
+             "up_proj",
+         ],
+     }
+
+     # LoRA specific attributes
+     supported_lora_modules = [
+         "qkv_proj",
+         "o_proj",
+         "gate_up_proj",
+         "down_proj",
+     ]
+     # Gemma does not apply LoRA to the embedding layer.
+     embedding_modules = {}
+     embedding_padding_modules = []
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         cache_config: Optional[CacheConfig] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+         lora_config: Optional[LoRAConfig] = None,
+     ) -> None:
+         del lora_config  # Unused.
+         super().__init__()
+         self.config = config
+         self.quant_config = quant_config
+         self.model = Gemma2Model(config, cache_config, quant_config)
+         self.logits_processor = LogitsProcessor(config)
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+         return self.logits_processor(
+             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
+         )
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+         params_dict = dict(self.named_parameters())
+         loaded_params: Set[str] = set()
+         for name, loaded_weight in weights:
+             for param_name, shard_name, shard_id in stacked_params_mapping:
+                 if shard_name not in name:
+                     continue
+                 name = name.replace(shard_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 # lm_head is not used in vllm as it is tied with embed_token.
+                 # To prevent errors, skip loading lm_head.weight.
+                 if "lm_head.weight" in name:
+                     continue
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                 weight_loader(param, loaded_weight)
+             loaded_params.add(name)
+
+         unloaded_params = params_dict.keys() - loaded_params
+         if unloaded_params:
+             raise RuntimeError(
+                 "Some weights are not initialized from checkpoints: "
+                 f"{unloaded_params}"
+             )
+
+
+ EntryClass = Gemma2ForCausalLM
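Note on the new GemmaRMSNorm above: its docstring names two departures from a stock RMSNorm, namely that the weight is applied as x * (1 + w) (so a zero-initialized weight is the identity) and that the result is cast back to the input dtype only after the weight multiplication. The snippet below is a standalone sketch, not part of the wheel; the function name gemma_rms_norm_reference and the sample shapes are illustrative only, and it restates the math in plain PyTorch so the two rules can be checked in isolation.

    # Standalone sketch: restate the two GemmaRMSNorm rules from the docstring
    # above (x * (1 + w), and casting to the input dtype after the multiply).
    # Nothing here is imported from sglang or vllm.
    import torch


    def gemma_rms_norm_reference(
        x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
    ) -> torch.Tensor:
        orig_dtype = x.dtype
        x = x.float()
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + eps)
        # Rule 1: multiply by (1 + w), so a zero weight leaves the scale unchanged.
        # Rule 2: cast back to the original dtype only after applying the weight.
        return (x * (1.0 + weight.float())).to(orig_dtype)


    x = torch.randn(2, 8, dtype=torch.bfloat16)
    w = torch.zeros(8)  # zero-initialized, as in the class above
    out = gemma_rms_norm_reference(x, w)
    assert out.dtype == x.dtype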
sglang/srt/models/llama2.py
@@ -163,9 +163,9 @@ class LlamaDecoderLayer(nn.Module):
          if rope_scaling is not None and getattr(
              config, "original_max_position_embeddings", None
          ):
-             rope_scaling["original_max_position_embeddings"] = (
-                 config.original_max_position_embeddings
-             )
+             rope_scaling[
+                 "original_max_position_embeddings"
+             ] = config.original_max_position_embeddings
          rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
          max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
          self.self_attn = LlamaAttention(
sglang/srt/models/llama_classification.py
@@ -5,14 +5,12 @@ import tqdm
  from torch import nn
  from transformers import LlamaConfig
  from vllm.config import CacheConfig
- from vllm.distributed import (
-     get_tensor_model_parallel_rank,
- )
+ from vllm.distributed import get_tensor_model_parallel_rank
  from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

- from sglang.srt.managers.controller.model_runner import InputMetadata
  from sglang.srt.layers.logits_processor import LogitProcessorOutput
+ from sglang.srt.managers.controller.model_runner import InputMetadata
  from sglang.srt.models.llama2 import LlamaModel


@@ -28,7 +26,9 @@ class LlamaForClassification(nn.Module):
          self.quant_config = quant_config
          self.model = LlamaModel(config, quant_config=quant_config)

-         self.classification_head = nn.Linear(config.hidden_size, config.classification_out_size)
+         self.classification_head = nn.Linear(
+             config.hidden_size, config.classification_out_size
+         )
          self.eos_token_id = config.eos_token_id

      def forward(
@@ -45,7 +45,9 @@ class LlamaForClassification(nn.Module):

          if scores.shape[0] != input_metadata.batch_size:
              print("Warning: the EOS tokens are missing in some sentences.")
-             scores = torch.ones((input_metadata.batch_size, self.config.classification_out_size)).to(input_ids.device)
+             scores = torch.ones(
+                 (input_metadata.batch_size, self.config.classification_out_size)
+             ).to(input_ids.device)

          return LogitProcessorOutput(
              next_token_logits=scores,
@@ -101,4 +103,5 @@ class LlamaForClassification(nn.Module):
                  weight_loader = getattr(param, "weight_loader", default_weight_loader)
                  weight_loader(param, loaded_weight)

- EntryClass = LlamaForClassification
+
+ EntryClass = LlamaForClassification
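For context on the reformatted fallback in LlamaForClassification.forward above: when fewer EOS positions are found than there are sequences in the batch, the model substitutes a constant placeholder so the output keeps the shape (batch_size, classification_out_size). The snippet below is a minimal standalone restatement of that guard, not the package's code; the shapes and variable names are illustrative only and follow the diff.

    # Minimal restatement of the EOS-fallback guard shown in the diff above.
    # `scores` stands in for the classification-head output gathered at EOS
    # positions; batch_size and out_size are illustrative values.
    import torch

    batch_size, out_size = 4, 3
    scores = torch.randn(3, out_size)  # one sequence is missing its EOS token

    if scores.shape[0] != batch_size:
        print("Warning: the EOS tokens are missing in some sentences.")
        # Fall back to a constant placeholder so downstream code always sees
        # a (batch_size, out_size) tensor.
        scores = torch.ones((batch_size, out_size))

    assert scores.shape == (batch_size, out_size)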