sglang 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/api.py +13 -1
  2. sglang/bench_latency.py +10 -5
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/global_config.py +1 -1
  6. sglang/lang/backend/runtime_endpoint.py +60 -49
  7. sglang/lang/chat_template.py +10 -5
  8. sglang/lang/compiler.py +4 -0
  9. sglang/lang/interpreter.py +5 -2
  10. sglang/lang/ir.py +22 -4
  11. sglang/launch_server.py +8 -1
  12. sglang/srt/constrained/jump_forward.py +13 -2
  13. sglang/srt/conversation.py +50 -1
  14. sglang/srt/hf_transformers_utils.py +22 -23
  15. sglang/srt/layers/activation.py +24 -2
  16. sglang/srt/layers/decode_attention.py +338 -50
  17. sglang/srt/layers/extend_attention.py +3 -1
  18. sglang/srt/layers/fused_moe/__init__.py +1 -0
  19. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  20. sglang/srt/layers/fused_moe/layer.py +587 -0
  21. sglang/srt/layers/layernorm.py +3 -0
  22. sglang/srt/layers/logits_processor.py +64 -27
  23. sglang/srt/layers/radix_attention.py +41 -18
  24. sglang/srt/layers/sampler.py +154 -0
  25. sglang/srt/managers/controller_multi.py +2 -8
  26. sglang/srt/managers/controller_single.py +7 -10
  27. sglang/srt/managers/detokenizer_manager.py +20 -9
  28. sglang/srt/managers/io_struct.py +44 -11
  29. sglang/srt/managers/policy_scheduler.py +5 -2
  30. sglang/srt/managers/schedule_batch.py +59 -179
  31. sglang/srt/managers/tokenizer_manager.py +193 -84
  32. sglang/srt/managers/tp_worker.py +131 -50
  33. sglang/srt/mem_cache/memory_pool.py +82 -8
  34. sglang/srt/mm_utils.py +79 -7
  35. sglang/srt/model_executor/cuda_graph_runner.py +97 -28
  36. sglang/srt/model_executor/forward_batch_info.py +188 -82
  37. sglang/srt/model_executor/model_runner.py +269 -87
  38. sglang/srt/models/chatglm.py +6 -14
  39. sglang/srt/models/commandr.py +6 -2
  40. sglang/srt/models/dbrx.py +5 -1
  41. sglang/srt/models/deepseek.py +7 -3
  42. sglang/srt/models/deepseek_v2.py +12 -7
  43. sglang/srt/models/gemma.py +6 -2
  44. sglang/srt/models/gemma2.py +22 -8
  45. sglang/srt/models/gpt_bigcode.py +5 -1
  46. sglang/srt/models/grok.py +66 -398
  47. sglang/srt/models/internlm2.py +5 -1
  48. sglang/srt/models/llama2.py +7 -3
  49. sglang/srt/models/llama_classification.py +2 -2
  50. sglang/srt/models/llama_embedding.py +4 -0
  51. sglang/srt/models/llava.py +176 -59
  52. sglang/srt/models/minicpm.py +7 -3
  53. sglang/srt/models/mixtral.py +61 -255
  54. sglang/srt/models/mixtral_quant.py +6 -5
  55. sglang/srt/models/qwen.py +7 -4
  56. sglang/srt/models/qwen2.py +15 -5
  57. sglang/srt/models/qwen2_moe.py +7 -16
  58. sglang/srt/models/stablelm.py +6 -2
  59. sglang/srt/openai_api/adapter.py +149 -58
  60. sglang/srt/sampling/sampling_batch_info.py +209 -0
  61. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -4
  62. sglang/srt/server.py +107 -71
  63. sglang/srt/server_args.py +49 -15
  64. sglang/srt/utils.py +27 -18
  65. sglang/test/runners.py +38 -38
  66. sglang/test/simple_eval_common.py +9 -10
  67. sglang/test/simple_eval_gpqa.py +2 -1
  68. sglang/test/simple_eval_humaneval.py +2 -2
  69. sglang/test/simple_eval_math.py +2 -1
  70. sglang/test/simple_eval_mmlu.py +2 -1
  71. sglang/test/test_activation.py +55 -0
  72. sglang/test/test_programs.py +32 -5
  73. sglang/test/test_utils.py +37 -50
  74. sglang/version.py +1 -1
  75. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/METADATA +102 -27
  76. sglang-0.2.14.dist-info/RECORD +114 -0
  77. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
  78. sglang/launch_server_llavavid.py +0 -29
  79. sglang/srt/model_loader/model_loader.py +0 -292
  80. sglang/srt/model_loader/utils.py +0 -275
  81. sglang-0.2.12.dist-info/RECORD +0 -112
  82. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
  83. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek.py

@@ -27,9 +27,7 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_moe
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -44,8 +42,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -385,6 +386,7 @@ class DeepseekForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -394,9 +396,11 @@
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
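A pattern that repeats across the model files in this release: each model gains a Sampler and its forward now returns a (sample_output, logits_output) pair instead of the bare logits-processor output. The sketch below is a minimal, self-contained illustration of that return convention; ToyLogitsProcessor, ToySampler, and ToyCausalLM are stand-ins invented for this example, not the sglang classes, which additionally take an InputMetadata / sampling_info argument.

import torch

# Hypothetical stand-ins for sglang's LogitsProcessor / Sampler; invented for
# illustration only.
class ToyLogitsProcessor(torch.nn.Module):
    def forward(self, hidden_states, lm_head_weight):
        return hidden_states @ lm_head_weight.t()        # [batch, vocab]

class ToySampler(torch.nn.Module):
    def forward(self, logits, temperature=1.0):
        probs = torch.softmax(logits / temperature, dim=-1)
        return torch.multinomial(probs, num_samples=1)   # sampled token ids

class ToyCausalLM(torch.nn.Module):
    def __init__(self, hidden=8, vocab=32):
        super().__init__()
        self.lm_head = torch.nn.Linear(hidden, vocab, bias=False)
        self.logits_processor = ToyLogitsProcessor()
        self.sampler = ToySampler()

    def forward(self, hidden_states):
        # Same shape as the diff: compute logits, sample, return both.
        logits_output = self.logits_processor(hidden_states, self.lm_head.weight)
        sample_output = self.sampler(logits_output)
        return sample_output, logits_output

model = ToyCausalLM()
tokens, logits = model(torch.randn(2, 8))   # tokens: [2, 1], logits: [2, 32]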
sglang/srt/models/deepseek_v2.py

@@ -26,9 +26,7 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     MergedColumnParallelLinear,
@@ -43,8 +41,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
@@ -445,11 +446,12 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_nope_out = q_input[..., : self.kv_lora_rank]
         torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1))
 
-        k_input = self.kv_a_proj_with_mqa(hidden_states)[0].unsqueeze(1)
-        k_pe = k_input[..., self.kv_lora_rank :]
-        v_input = k_input[..., : self.kv_lora_rank]
-        v_input = self.kv_a_layernorm(v_input.contiguous())
+        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+        v_input = latent_cache[..., : self.kv_lora_rank]
+        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
+        k_input = latent_cache.unsqueeze(1)
         k_input[..., : self.kv_lora_rank] = v_input
+        k_pe = k_input[..., self.kv_lora_rank :]
 
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
         q_input[..., self.kv_lora_rank :] = q_pe
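In the MLA hunk above, the projection output is kept as a flat latent_cache: the compressed KV slice is normalized before the singleton head dimension is added, the normalized values are written back into k_input, and k_pe is sliced out afterwards. The shape-only sketch below mirrors those steps with arbitrary small dimensions and a plain LayerNorm standing in for kv_a_layernorm; all names and sizes are illustrative, not taken from the model config.

import torch

kv_lora_rank, qk_rope_head_dim, num_tokens = 4, 2, 3
# Stand-in for kv_a_proj_with_mqa(hidden_states)[0]: [num_tokens, kv_lora_rank + rope_dim]
latent_cache = torch.randn(num_tokens, kv_lora_rank + qk_rope_head_dim)
kv_a_layernorm = torch.nn.LayerNorm(kv_lora_rank)  # stand-in for the model's norm

v_input = latent_cache[..., :kv_lora_rank]                   # compressed KV part
v_input = kv_a_layernorm(v_input.contiguous()).unsqueeze(1)  # [num_tokens, 1, kv_lora_rank]
k_input = latent_cache.unsqueeze(1)                          # [num_tokens, 1, kv_lora_rank + rope_dim]
k_input[..., :kv_lora_rank] = v_input                        # write normalized values back
k_pe = k_input[..., kv_lora_rank:]                           # rope part, sliced after the write-back

print(v_input.shape, k_input.shape, k_pe.shape)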
@@ -631,6 +633,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     def forward(
         self,
@@ -639,9 +642,11 @@
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/gemma.py

@@ -24,7 +24,6 @@ from transformers import PretrainedConfig
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import GeluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -35,8 +34,10 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -287,6 +288,7 @@ class GemmaForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.model = GemmaModel(config, quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -297,9 +299,11 @@
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return (sample_output, logits_output)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/gemma2.py

@@ -25,7 +25,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size
 
 # FIXME: temporary solution, remove after next vllm release
 from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.activation import GeluAndMul
 
 # from vllm.model_executor.layers.layernorm import GemmaRMSNorm
 from vllm.model_executor.layers.linear import (
@@ -39,11 +38,19 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
+# Aligned with HF's implementation, using sliding window inclusive with the last token
+# SGLang assumes exclusive
+def get_attention_sliding_window_size(config):
+    return config.sliding_window - 1
+
+
 class GemmaRMSNorm(CustomOp):
     """RMS normalization for Gemma.
 
@@ -129,7 +136,7 @@ class Gemma2MLP(nn.Module):
                 "function. Please set `hidden_act` and `hidden_activation` to "
                 "`gelu_pytorch_tanh`."
             )
-        self.act_fn = GeluAndMul(approximate="tanh")
+        self.act_fn = GeluAndMul()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         gate_up, _ = self.gate_up_proj(x)
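The MLP now constructs GeluAndMul from sglang.srt.layers.activation without an explicit approximate="tanh" argument; the default it relies on is not visible in this diff. As a reference only, the snippet below sketches GeluAndMul-style semantics in plain PyTorch (gelu_and_mul is a hypothetical helper, not the sglang class) and compares the tanh-approximate and exact GELU variants.

import torch
import torch.nn.functional as F

def gelu_and_mul(x: torch.Tensor, approximate: str = "tanh") -> torch.Tensor:
    # GeluAndMul-style op: split the last dim in half, apply GELU to the first
    # half, multiply elementwise by the second half (the gate/up pair).
    gate, up = x.chunk(2, dim=-1)
    return F.gelu(gate, approximate=approximate) * up

x = torch.randn(2, 8)
out_tanh = gelu_and_mul(x, approximate="tanh")   # matches `gelu_pytorch_tanh`
out_exact = gelu_and_mul(x, approximate="none")  # exact erf-based GELU
print(out_tanh.shape, (out_tanh - out_exact).abs().max())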
@@ -200,17 +207,18 @@ class Gemma2Attention(nn.Module):
             dtype=torch.get_default_dtype(),
         )
 
-        # from vLLM: FIXME(woosuk): While Gemma 2 uses sliding window attention for every
-        # odd layer, vLLM currently ignores it and uses global attention for
-        # all layers.
-        use_sliding_window = layer_idx % 2 == 1 and config.sliding_window is not None
-        del use_sliding_window  # Unused.
+        use_sliding_window = layer_idx % 2 == 0 and hasattr(config, "sliding_window")
         self.attn = RadixAttention(
             self.num_heads,
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_idx,
+            sliding_window_size=(
+                get_attention_sliding_window_size(config)
+                if use_sliding_window
+                else None
+            ),
             logit_cap=self.config.attn_logit_softcapping,
         )
 
@@ -389,6 +397,7 @@ class Gemma2ForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.model = Gemma2Model(config, cache_config, quant_config)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -399,9 +408,14 @@
        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
+
+    def get_attention_sliding_window_size(self):
+        return get_attention_sliding_window_size(self.config)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
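Taken together, the gemma2.py hunks route sliding-window attention through on alternating layers: the module-level helper converts the HF config's inclusive sliding_window into the exclusive count SGLang expects, and Gemma2Attention passes that value only when layer_idx is even. Below is a self-contained sketch of that wiring, where DummyConfig and DummyAttention are stand-ins invented for illustration (the real code reads the HF Gemma-2 config and passes the value to RadixAttention).

# Stand-ins for illustration only.
class DummyConfig:
    sliding_window = 4096  # HF-style: window counted inclusive of the current token

def get_attention_sliding_window_size(config):
    # Same conversion as the diff: SGLang counts only the preceding tokens.
    return config.sliding_window - 1

class DummyAttention:
    def __init__(self, layer_id, sliding_window_size=None):
        self.layer_id = layer_id
        self.sliding_window_size = sliding_window_size

def build_attention(layer_idx, config):
    # Even layers get the window, mirroring `layer_idx % 2 == 0` in the hunk above.
    use_sliding_window = layer_idx % 2 == 0 and hasattr(config, "sliding_window")
    return DummyAttention(
        layer_id=layer_idx,
        sliding_window_size=(
            get_attention_sliding_window_size(config) if use_sliding_window else None
        ),
    )

layers = [build_attention(i, DummyConfig()) for i in range(4)]
print([layer.sliding_window_size for layer in layers])  # [4095, None, 4095, None]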
sglang/srt/models/gpt_bigcode.py

@@ -35,6 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -261,6 +262,7 @@ class GPTBigCodeForCausalLM(nn.Module):
         if lora_config:
             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -270,9 +272,11 @@
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))