sglang 0.4.0__py3-none-any.whl → 0.4.0.post2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (72)
  1. sglang/__init__.py +1 -1
  2. sglang/bench_offline_throughput.py +18 -6
  3. sglang/bench_one_batch.py +13 -0
  4. sglang/bench_serving.py +8 -1
  5. sglang/check_env.py +140 -48
  6. sglang/lang/backend/runtime_endpoint.py +1 -0
  7. sglang/lang/chat_template.py +32 -0
  8. sglang/llama3_eval.py +316 -0
  9. sglang/srt/constrained/outlines_backend.py +5 -0
  10. sglang/srt/constrained/xgrammar_backend.py +9 -6
  11. sglang/srt/layers/attention/__init__.py +5 -2
  12. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  13. sglang/srt/layers/attention/flashinfer_backend.py +22 -5
  14. sglang/srt/layers/attention/torch_native_backend.py +22 -8
  15. sglang/srt/layers/attention/triton_backend.py +38 -33
  16. sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
  17. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  18. sglang/srt/layers/ep_moe/__init__.py +0 -0
  19. sglang/srt/layers/ep_moe/kernels.py +349 -0
  20. sglang/srt/layers/ep_moe/layer.py +665 -0
  21. sglang/srt/layers/fused_moe_triton/fused_moe.py +64 -21
  22. sglang/srt/layers/fused_moe_triton/layer.py +1 -1
  23. sglang/srt/layers/logits_processor.py +133 -95
  24. sglang/srt/layers/quantization/__init__.py +2 -47
  25. sglang/srt/layers/quantization/fp8.py +607 -0
  26. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  27. sglang/srt/layers/radix_attention.py +11 -2
  28. sglang/srt/layers/sampler.py +29 -5
  29. sglang/srt/layers/torchao_utils.py +58 -45
  30. sglang/srt/managers/detokenizer_manager.py +37 -17
  31. sglang/srt/managers/io_struct.py +39 -10
  32. sglang/srt/managers/schedule_batch.py +39 -24
  33. sglang/srt/managers/schedule_policy.py +64 -5
  34. sglang/srt/managers/scheduler.py +236 -197
  35. sglang/srt/managers/tokenizer_manager.py +99 -58
  36. sglang/srt/managers/tp_worker_overlap_thread.py +7 -5
  37. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  38. sglang/srt/mem_cache/chunk_cache.py +2 -2
  39. sglang/srt/mem_cache/memory_pool.py +5 -1
  40. sglang/srt/mem_cache/radix_cache.py +12 -2
  41. sglang/srt/model_executor/cuda_graph_runner.py +39 -11
  42. sglang/srt/model_executor/model_runner.py +24 -9
  43. sglang/srt/model_parallel.py +67 -10
  44. sglang/srt/models/commandr.py +2 -2
  45. sglang/srt/models/deepseek_v2.py +87 -7
  46. sglang/srt/models/gemma2.py +34 -0
  47. sglang/srt/models/gemma2_reward.py +0 -1
  48. sglang/srt/models/granite.py +517 -0
  49. sglang/srt/models/grok.py +72 -13
  50. sglang/srt/models/llama.py +22 -5
  51. sglang/srt/models/llama_classification.py +11 -23
  52. sglang/srt/models/llama_reward.py +0 -2
  53. sglang/srt/models/llava.py +37 -14
  54. sglang/srt/models/mixtral.py +12 -9
  55. sglang/srt/models/phi3_small.py +0 -5
  56. sglang/srt/models/qwen2.py +20 -0
  57. sglang/srt/models/qwen2_moe.py +0 -5
  58. sglang/srt/models/torch_native_llama.py +0 -5
  59. sglang/srt/openai_api/adapter.py +4 -0
  60. sglang/srt/openai_api/protocol.py +9 -4
  61. sglang/srt/sampling/sampling_batch_info.py +9 -8
  62. sglang/srt/server.py +4 -4
  63. sglang/srt/server_args.py +62 -13
  64. sglang/srt/utils.py +57 -10
  65. sglang/test/test_utils.py +3 -2
  66. sglang/utils.py +10 -3
  67. sglang/version.py +1 -1
  68. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/METADATA +15 -9
  69. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/RECORD +72 -65
  70. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/LICENSE +0 -0
  71. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/WHEEL +0 -0
  72. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/top_level.txt +0 -0
sglang/srt/model_parallel.py +67 -10

@@ -2,18 +2,18 @@
 Common utilities for torch model parallelism.
 """
 
-from typing import Optional
+from typing import Optional, Sequence
 
 import torch
+import torch.nn as nn
 from torch.distributed.device_mesh import DeviceMesh
 
 try:
-    from torch.distributed.tensor import DTensor, Shard
+    import torch.distributed.tensor as dt
 except ImportError:
     # torch 2.4 or older
-    from torch.distributed._tensor import DTensor, Shard
+    import torch.distributed._tensor as dt
 
-from torch.distributed._functional_collectives import AsyncCollectiveTensor
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     RowwiseParallel,
@@ -21,6 +21,50 @@ from torch.distributed.tensor.parallel import (
 )
 
 
+def _shard_tensor(
+    full_tensor: torch.Tensor,
+    device_mesh: DeviceMesh,
+    placements: Sequence[dt.Shard],
+) -> "dt.DTensor":
+    """
+    Locally shards a full tensor based on indicated sharding arrangement, and
+    returns a DTensor containing the local shard.
+
+    .. warning:: This is a private API that is subject to change. It skips the
+        communication otherwise required by `distribute_tensor`. It is only
+        applicable to cases where all ranks have the same `full_tensor`. For
+        example, in distributed inference all ranks load from the same
+        checkpoint. This API will not check for data equality between ranks, it
+        is thus user's responsibility to ensure the `full_tensor` is the same
+        across ranks.
+
+    Args:
+        full_tensor (torch.Tensor): the full tensor to be sharded.
+        device_mesh (:class:`DeviceMesh`): DeviceMesh to place the
+            DTensor. Must have same dimension as the number of placements.
+        placements (Sequence[:class:`Shard`]): the placements that
+            describes how to place the local tensor on DeviceMesh.
+
+    Returns:
+        A :class:`DTensor` object with the shard as its local tensor.
+
+    Examples:
+        >>> # xdoctest: +SKIP("need world_size and rank")
+        >>> device_mesh = dist.init_device_mesh("cuda", (world_size,))
+        >>> full_tensor = torch.arange(world_size, device=f"cuda:{rank}")
+        >>> dtensor = _shard_tensor(full_tensor, device_mesh, [Shard(1)])
+    """
+    shape, offset = dt._utils.compute_local_shape_and_global_offset(
+        full_tensor.shape, device_mesh, placements
+    )
+    slices = [
+        slice(cur_offset, cur_offset + cur_shape)
+        for cur_shape, cur_offset in zip(shape, offset)
+    ]
+    local_tensor = full_tensor[slices]
+    return dt.DTensor.from_local(local_tensor, device_mesh, placements)
+
+
 class ColwiseParallelSharded(ColwiseParallel):
     """
     A version of ColwiseParallel where the local weight has been already
@@ -34,7 +78,7 @@ class ColwiseParallelSharded(ColwiseParallel):
         # means Colwise as Linear is input * weight^T + bias, where
         # weight would become Shard(1)
         for name, param in module.named_parameters():
-            dtensor = DTensor.from_local(param, device_mesh, [Shard(0)])
+            dtensor = dt.DTensor.from_local(param, device_mesh, [dt.Shard(0)])
             dist_param = torch.nn.Parameter(dtensor, requires_grad=False)
             module.register_parameter(name, dist_param)
 
@@ -47,6 +91,23 @@ class RowwiseParallelMaybeWait(RowwiseParallel):
    AsyncCollectiveTensor and custom ops, such as `class RMSNorm(CustomOp)`.
    """
 
+    def _partition_linear_fn(self, name, module, device_mesh):
+        # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
+        # means Rowwise as nn.Linear is input * weight^T + bias, where
+        # weight would become Shard(0)
+        module.register_parameter(
+            "weight",
+            nn.Parameter(_shard_tensor(module.weight, device_mesh, [dt.Shard(1)])),
+        )
+        if getattr(module, "bias", None) is not None:
+            # The Linear module has bias
+            module.register_parameter(
+                "bias",
+                nn.Parameter(
+                    dt.distribute_tensor(module.bias, device_mesh, [dt.Replicate()])
+                ),
+            )
+
     @staticmethod
     def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
         outputs = super(
@@ -54,11 +115,7 @@ class RowwiseParallelMaybeWait(RowwiseParallel):
         )._prepare_output_fn(
             output_layouts, use_local_output, mod, outputs, device_mesh
         )
-        # wait for the output to be ready
-        if isinstance(outputs, AsyncCollectiveTensor):
-            return outputs.wait()
-        else:
-            return outputs
+        return torch.distributed._functional_collectives.wait_tensor(outputs)
 
 
 def tensor_parallel(
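
Note on the model_parallel.py hunks above: the new _shard_tensor wraps each rank's already-sharded local weight in a DTensor without the broadcast that distribute_tensor would perform, and RowwiseParallelMaybeWait now waits on the functional-collective output directly instead of special-casing AsyncCollectiveTensor. Below is a minimal sketch of how such parallel styles are typically applied with torch.distributed.tensor.parallel.parallelize_module; the wrapper function and the sub-module names ("gate_up_proj", "down_proj") are assumptions for illustration, not taken from this diff.

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

from sglang.srt.model_parallel import ColwiseParallelSharded, RowwiseParallelMaybeWait

def shard_mlp(mlp: torch.nn.Module, world_size: int) -> torch.nn.Module:
    # One 1-D mesh over all ranks; each rank is assumed to already hold its local weight shard.
    mesh = init_device_mesh("cuda", (world_size,))
    plan = {
        "gate_up_proj": ColwiseParallelSharded(),  # wraps the pre-sharded weight as a DTensor with Shard(0)
        "down_proj": RowwiseParallelMaybeWait(),   # row-parallel; waits on the async all-reduce output
    }
    return parallelize_module(mlp, mesh, plan)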
sglang/srt/models/commandr.py +2 -2

@@ -62,10 +62,10 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import get_compiler_backend, set_weight_attrs
 
 
-@torch.compile
+@torch.compile(backend=get_compiler_backend())
 def layer_norm_func(hidden_states, weight, variance_epsilon):
     input_dtype = hidden_states.dtype
     hidden_states = hidden_states.to(torch.float32)
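
Note on the commandr.py hunk above: the hard-coded @torch.compile decorator is replaced by one whose backend comes from sglang.srt.utils.get_compiler_backend, whose implementation is not shown in this diff. The sketch below only illustrates the pattern of parameterizing the torch.compile backend; _pick_backend is a hypothetical stand-in and its selection logic is an assumption, not sglang's actual code.

import torch

def _pick_backend() -> str:
    # Hypothetical stand-in for get_compiler_backend(); the real helper may
    # choose differently (e.g. a vendor backend on non-CUDA accelerators).
    return "inductor" if torch.cuda.is_available() else "eager"

@torch.compile(backend=_pick_backend())
def rms_norm_like(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Small compiled function just to exercise the decorator.
    x = x.to(torch.float32)
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)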
sglang/srt/models/deepseek_v2.py +87 -7

@@ -21,6 +21,7 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
+from vllm import _custom_ops as ops
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -30,6 +31,7 @@ from vllm.distributed import (
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.ep_moe.layer import EPMoE
 from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -112,12 +114,12 @@ class DeepseekV2MoE(nn.Module):
                 "Only silu is supported for now."
             )
 
-        self.experts = FusedMoE(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        self.experts = MoEImpl(
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             use_grouped_topk=True,
@@ -453,7 +455,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale
 
-        self.attn = RadixAttention(
+        self.attn_mqa = RadixAttention(
            self.num_local_heads,
            self.kv_lora_rank + self.qk_rope_head_dim,
            self.scaling,
@@ -462,6 +464,15 @@ class DeepseekV2AttentionMLA(nn.Module):
            v_head_dim=self.kv_lora_rank,
        )
 
+        self.attn_mha = RadixAttention(
+            self.num_local_heads,
+            self.qk_nope_head_dim + self.qk_rope_head_dim,
+            self.scaling,
+            num_kv_heads=self.num_local_heads,
+            layer_id=layer_id,
+            v_head_dim=self.v_head_dim,
+        )
+
         self.w_kc = None
         self.w_vc = None
         self.w_scale = None
@@ -471,6 +482,63 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        # Use normal computation for prefill and use weight absorption for extend/decode
+        if (
+            forward_batch.forward_mode.is_extend()
+            and forward_batch.extend_prefix_lens.sum() == 0
+        ):
+            return self.forward_normal(positions, hidden_states, forward_batch)
+        else:
+            return self.forward_absorb(positions, hidden_states, forward_batch)
+
+    def forward_normal(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        if self.q_lora_rank is not None:
+            q = self.q_a_proj(hidden_states)[0]
+            q = self.q_a_layernorm(q)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+        else:
+            q = self.q_proj(hidden_states)[0].view(
+                -1, self.num_local_heads, self.qk_head_dim
+            )
+        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+        latent_cache = latent_cache.unsqueeze(1)
+        kv_a = self.kv_a_layernorm(kv_a.contiguous())
+        kv = self.kv_b_proj(kv_a)[0]
+        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
+        k_nope = kv[..., : self.qk_nope_head_dim]
+        v = kv[..., self.qk_nope_head_dim :]
+        k_pe = latent_cache[:, :, self.kv_lora_rank :]
+        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+        q[..., self.qk_nope_head_dim :] = q_pe
+        k = torch.empty_like(q)
+        k[..., : self.qk_nope_head_dim] = k_nope
+        k[..., self.qk_nope_head_dim :] = k_pe
+
+        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+        latent_cache[:, :, self.kv_lora_rank :] = k_pe
+
+        # Save latent cache
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+        )
+        attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
+        attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+    def forward_absorb(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         q_len = hidden_states.shape[0]
         q_input = hidden_states.new_empty(
@@ -508,7 +576,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_input[..., self.kv_lora_rank :] = q_pe
         k_input[..., self.kv_lora_rank :] = k_pe
 
-        attn_output = self.attn(q_input, k_input, v_input, forward_batch)
+        attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
 
         if self.w_vc.dtype == torch.float8_e4m3fn:
@@ -767,7 +835,8 @@ class DeepseekV2ForCausalLM(nn.Module):
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
@@ -828,14 +897,25 @@ class DeepseekV2ForCausalLM(nn.Module):
         if not global_server_args_dict["disable_mla"]:
             for layer_id in range(self.config.num_hidden_layers):
                 self_attn = self.model.layers[layer_id].self_attn
-                w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
+                if hasattr(self_attn.kv_b_proj, "qweight"):
+                    # AWQ compatible
+                    w = ops.awq_dequantize(
+                        self_attn.kv_b_proj.qweight,
+                        self_attn.kv_b_proj.scales,
+                        self_attn.kv_b_proj.qzeros,
+                        0,
+                        0,
+                        0,
+                    ).T
+                else:
+                    w = self_attn.kv_b_proj.weight
+                w_kc, w_vc = w.unflatten(
                     0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                 ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
                 self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
                 self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
                 if hasattr(self_attn.kv_b_proj, "weight_scale"):
                     self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-                del self_attn.kv_b_proj
 
 
 EntryClass = DeepseekV2ForCausalLM
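
Note on the DeepseekV2AttentionMLA hunks above: forward_normal runs standard multi-head attention over full keys/values, while forward_absorb keeps the compressed latent KV cache and folds the K/V up-projections (w_kc and w_vc, split out of kv_b_proj as in the load_weights hunk) into the query and the attention output. The following is an illustrative, self-contained sketch of that absorption idea with made-up shapes; it is not the sglang implementation.

import torch

def absorption_sketch(T=4, H=2, d_nope=8, r=16, d_v=8):
    # T tokens, H heads, d_nope = qk_nope_head_dim, r = kv_lora_rank, d_v = v_head_dim
    q_nope = torch.randn(T, H, d_nope)
    w_kc = torch.randn(H, d_nope, r)  # K part of the kv_b_proj weight
    w_vc = torch.randn(H, r, d_v)     # V part of the kv_b_proj weight
    # Fold W_K into the query so attention scores can be taken directly
    # against the r-dimensional latent cache (the attn_mqa path above).
    q_latent = torch.einsum("thd,hdr->thr", q_nope, w_kc)
    # ... attention over the latent cache yields an output in latent space ...
    o_latent = torch.randn(T, H, r)
    # Fold W_V back in after attention to recover per-head value outputs.
    out = torch.einsum("thr,hrv->thv", o_latent, w_vc)
    return q_latent, out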
sglang/srt/models/gemma2.py +34 -0

@@ -355,6 +355,40 @@ class Gemma2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.model.embed_tokens, forward_batch
         )
 
+    def get_hidden_dim(self, module_name):
+        # return input_dim, output_dim
+        if module_name in ["q_proj", "qkv_proj"]:
+            return (
+                self.config.hidden_size,
+                self.config.head_dim * self.config.num_attention_heads,
+            )
+        elif module_name in ["o_proj"]:
+            return (
+                self.config.head_dim * self.config.num_attention_heads,
+                self.config.hidden_size,
+            )
+        elif module_name in ["kv_proj"]:
+            return (
+                self.config.hidden_size,
+                self.config.head_dim * self.config.num_key_value_heads,
+            )
+        elif module_name == "gate_up_proj":
+            return self.config.hidden_size, self.config.intermediate_size
+        elif module_name == "down_proj":
+            return self.config.intermediate_size, self.config.hidden_size
+        else:
+            raise NotImplementedError()
+
+    def get_module_name(self, name):
+        params_mapping = {
+            "q_proj": "qkv_proj",
+            "k_proj": "qkv_proj",
+            "v_proj": "qkv_proj",
+            "gate_proj": "gate_up_proj",
+            "up_proj": "gate_up_proj",
+        }
+        return params_mapping.get(name, name)
+
     def get_attention_sliding_window_size(self):
         return get_attention_sliding_window_size(self.config)
 
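
Note on the gemma2.py hunk above: get_hidden_dim and get_module_name give Gemma2ForCausalLM the same introspection hooks other sglang models expose, mapping per-projection checkpoint names onto the fused modules and reporting their input/output sizes. A small sketch of how a weight or adapter loader might consume them; the resolve_module helper itself is hypothetical.

def resolve_module(model, ckpt_weight_name: str):
    # Map a checkpoint projection name to the fused module that owns it,
    # e.g. "q_proj" -> "qkv_proj", "up_proj" -> "gate_up_proj".
    module_name = model.get_module_name(ckpt_weight_name)
    # Input/output sizes for that module, e.g. to size adapter matrices.
    input_dim, output_dim = model.get_hidden_dim(module_name)
    return module_name, input_dim, output_dim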
sglang/srt/models/gemma2_reward.py +0 -1

@@ -32,7 +32,6 @@ class Gemma2ForSequenceClassification(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
-        self.torchao_config = None
         self.quant_config = quant_config
         self.num_labels = config.num_labels
         self.model = Gemma2Model(config, quant_config=quant_config)