sglang 0.4.1.post6__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +24 -16
  4. sglang/bench_one_batch.py +51 -3
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +37 -28
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +15 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/model_config.py +1 -0
  13. sglang/srt/constrained/base_grammar_backend.py +21 -0
  14. sglang/srt/constrained/xgrammar_backend.py +8 -4
  15. sglang/srt/conversation.py +14 -1
  16. sglang/srt/distributed/__init__.py +3 -3
  17. sglang/srt/distributed/communication_op.py +2 -1
  18. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  21. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  22. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  23. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  24. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  25. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  26. sglang/srt/distributed/parallel_state.py +1 -1
  27. sglang/srt/distributed/utils.py +2 -1
  28. sglang/srt/entrypoints/engine.py +449 -0
  29. sglang/srt/entrypoints/http_server.py +579 -0
  30. sglang/srt/layers/activation.py +3 -3
  31. sglang/srt/layers/attention/flashinfer_backend.py +10 -9
  32. sglang/srt/layers/attention/triton_backend.py +4 -6
  33. sglang/srt/layers/attention/vision.py +204 -0
  34. sglang/srt/layers/dp_attention.py +69 -0
  35. sglang/srt/layers/linear.py +41 -5
  36. sglang/srt/layers/logits_processor.py +48 -63
  37. sglang/srt/layers/moe/ep_moe/layer.py +4 -4
  38. sglang/srt/layers/moe/fused_moe_native.py +69 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +29 -5
  41. sglang/srt/layers/parameter.py +2 -1
  42. sglang/srt/layers/quantization/__init__.py +20 -23
  43. sglang/srt/layers/quantization/fp8.py +6 -3
  44. sglang/srt/layers/quantization/modelopt_quant.py +1 -2
  45. sglang/srt/layers/quantization/w8a8_int8.py +1 -1
  46. sglang/srt/layers/radix_attention.py +2 -2
  47. sglang/srt/layers/rotary_embedding.py +1179 -31
  48. sglang/srt/layers/sampler.py +39 -1
  49. sglang/srt/layers/vocab_parallel_embedding.py +2 -2
  50. sglang/srt/lora/lora.py +1 -9
  51. sglang/srt/managers/configure_logging.py +3 -0
  52. sglang/srt/managers/data_parallel_controller.py +79 -72
  53. sglang/srt/managers/detokenizer_manager.py +23 -6
  54. sglang/srt/managers/image_processor.py +158 -2
  55. sglang/srt/managers/io_struct.py +25 -2
  56. sglang/srt/managers/schedule_batch.py +49 -22
  57. sglang/srt/managers/schedule_policy.py +26 -12
  58. sglang/srt/managers/scheduler.py +277 -178
  59. sglang/srt/managers/session_controller.py +1 -0
  60. sglang/srt/managers/tokenizer_manager.py +206 -121
  61. sglang/srt/managers/tp_worker.py +6 -4
  62. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  63. sglang/srt/managers/utils.py +44 -0
  64. sglang/srt/mem_cache/memory_pool.py +10 -32
  65. sglang/srt/metrics/collector.py +15 -6
  66. sglang/srt/model_executor/cuda_graph_runner.py +4 -6
  67. sglang/srt/model_executor/model_runner.py +37 -15
  68. sglang/srt/model_loader/loader.py +8 -6
  69. sglang/srt/model_loader/weight_utils.py +55 -2
  70. sglang/srt/models/baichuan.py +6 -6
  71. sglang/srt/models/chatglm.py +2 -2
  72. sglang/srt/models/commandr.py +3 -3
  73. sglang/srt/models/dbrx.py +4 -4
  74. sglang/srt/models/deepseek.py +3 -3
  75. sglang/srt/models/deepseek_v2.py +8 -8
  76. sglang/srt/models/exaone.py +2 -2
  77. sglang/srt/models/gemma.py +2 -2
  78. sglang/srt/models/gemma2.py +6 -24
  79. sglang/srt/models/gpt2.py +3 -5
  80. sglang/srt/models/gpt_bigcode.py +1 -1
  81. sglang/srt/models/granite.py +2 -2
  82. sglang/srt/models/grok.py +3 -3
  83. sglang/srt/models/internlm2.py +2 -2
  84. sglang/srt/models/llama.py +7 -5
  85. sglang/srt/models/minicpm.py +2 -2
  86. sglang/srt/models/minicpm3.py +6 -6
  87. sglang/srt/models/minicpmv.py +1238 -0
  88. sglang/srt/models/mixtral.py +3 -3
  89. sglang/srt/models/mixtral_quant.py +3 -3
  90. sglang/srt/models/mllama.py +2 -2
  91. sglang/srt/models/olmo.py +3 -3
  92. sglang/srt/models/olmo2.py +4 -4
  93. sglang/srt/models/olmoe.py +7 -13
  94. sglang/srt/models/phi3_small.py +2 -2
  95. sglang/srt/models/qwen.py +2 -2
  96. sglang/srt/models/qwen2.py +41 -4
  97. sglang/srt/models/qwen2_moe.py +3 -3
  98. sglang/srt/models/qwen2_vl.py +22 -122
  99. sglang/srt/models/stablelm.py +2 -2
  100. sglang/srt/models/torch_native_llama.py +3 -3
  101. sglang/srt/models/xverse.py +6 -6
  102. sglang/srt/models/xverse_moe.py +6 -6
  103. sglang/srt/openai_api/protocol.py +2 -0
  104. sglang/srt/sampling/custom_logit_processor.py +38 -0
  105. sglang/srt/sampling/sampling_batch_info.py +139 -4
  106. sglang/srt/sampling/sampling_params.py +3 -1
  107. sglang/srt/server.py +4 -1090
  108. sglang/srt/server_args.py +57 -14
  109. sglang/srt/utils.py +103 -65
  110. sglang/test/runners.py +8 -13
  111. sglang/test/test_programs.py +1 -1
  112. sglang/test/test_utils.py +3 -1
  113. sglang/utils.py +12 -2
  114. sglang/version.py +1 -1
  115. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +16 -5
  116. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +119 -115
  117. sglang/launch_server_llavavid.py +0 -25
  118. sglang/srt/constrained/__init__.py +0 -16
  119. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  120. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
  121. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
  122. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/vision.py
@@ -0,0 +1,204 @@
+ from __future__ import annotations
+
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+ from einops import rearrange, repeat
+
+ from sglang.srt.distributed import parallel_state
+ from sglang.srt.distributed import utils as dist_utils
+ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
+     context_attention_fwd,
+ )
+ from sglang.srt.layers.linear import (
+     ColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
+ from sglang.srt.layers.quantization import QuantizationConfig
+
+
+ def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
+     if not interleaved:
+         x1, x2 = x.chunk(2, dim=-1)
+         return torch.cat((-x2, x1), dim=-1)
+     else:
+         x1, x2 = x[..., ::2], x[..., 1::2]
+         return rearrange(
+             torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
+         )
+
+
+ def apply_rotary_emb_torch(
+     x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
+ ) -> torch.Tensor:
+     """
+     x: (batch_size, seqlen, nheads, headdim)
+     cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
+     """
+     ro_dim = cos.shape[-1] * 2
+     assert ro_dim <= x.shape[-1]
+     cos = repeat(
+         cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+     )
+     sin = repeat(
+         sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+     )
+     return torch.cat(
+         [
+             x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
+             x[..., ro_dim:],
+         ],
+         dim=-1,
+     )
+
+
+ def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+     t_ = t.float()
+     cos = freqs.cos()
+     sin = freqs.sin()
+     output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
+     return output
+
+
+ class VisionAttention(nn.Module):
+     """Multi-headed attention without any cache, mostly used for ViT."""
+
+     def __init__(
+         self,
+         embed_dim: int,
+         num_heads: int,
+         projection_size: int,
+         use_qkv_parallel: bool,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+         world_size = parallel_state.get_tensor_model_parallel_world_size()
+
+         self.hidden_size_per_attention_head = dist_utils.divide(
+             projection_size, num_heads
+         )
+         self.num_attention_heads_per_partition = dist_utils.divide(
+             num_heads, world_size
+         )
+         # self.tp_size = get_tensor_model_parallel_world_size()
+         # num_heads = self.num_heads_per_partition
+         self.use_qkv_parallel = use_qkv_parallel
+         if use_qkv_parallel:
+             self.head_dim = embed_dim // num_heads
+             self.qkv_proj = QKVParallelLinear(
+                 hidden_size=embed_dim,
+                 head_size=self.head_dim,
+                 total_num_heads=num_heads,
+                 quant_config=quant_config,
+                 prefix=f"{prefix}.qkv_proj",
+             )
+         else:
+             self.qkv_proj = ColumnParallelLinear(
+                 input_size=embed_dim,
+                 output_size=3 * projection_size,
+                 quant_config=quant_config,
+                 prefix=f"{prefix}.qkv_proj",
+             )
+         self.proj = RowParallelLinear(
+             input_size=embed_dim,
+             output_size=embed_dim,
+             quant_config=quant_config,
+             prefix=f"{prefix}.out_proj",
+         )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         cu_seqlens: Optional[torch.Tensor] = None,
+         rotary_pos_emb: torch.Tensor = None,
+     ) -> torch.Tensor:
+         """
+         Input shape: [b, s, embed_dim]
+         Output shape: [s, b, num_heads * head_size]
+         """
+
+         bsz, s, _ = x.shape
+         if self.use_qkv_parallel:
+             # [b, s, embed_dim] --> [b, s, embed_dim]
+             qkv, _ = self.qkv_proj(x)
+             q, k, v = qkv.chunk(3, dim=-1)
+
+             # [b, s, embed_dim] --> [b * s, num_heads, head_size]
+             q, k, v = [
+                 x.reshape(
+                     bsz * s, self.num_attention_heads_per_partition, -1
+                 ).contiguous()
+                 for x in (q, k, v)
+             ]
+         else:
+             # [b, s, embed_dim] --> [s, b, embed_dim]
+             x = rearrange(x, "b s ... -> s b ...")
+             # [s, b, embed_dim] --> [s, b, head * 3 * head_dim]
+             qkv, _ = self.qkv_proj(x)
+             # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
+             new_x_shape = qkv.size()[:-1] + (
+                 self.num_attention_heads_per_partition,
+                 3 * self.hidden_size_per_attention_head,
+             )
+             qkv = qkv.view(*new_x_shape)
+
+             # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
+             q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
+
+             # [s, b, head, head_dim] --> [b, s, head, head_dim]
+             q, k, v = [
+                 rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
+             ]
+
+         if rotary_pos_emb is not None:
+             q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
+             k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
+
+         if self.use_qkv_parallel:
+             pass
+         else:
+             # [b, s, head, head_dim] --> [b * s, head, head_dim]
+             q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
+
+         # [b * s, num_heads, head_size]
+         output = torch.empty_like(q)
+
+         seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cuda()
+         max_seqlen = seq_lens.max().item()
+
+         context_attention_fwd(
+             q,
+             k,
+             v,
+             output,
+             cu_seqlens.cuda(),
+             seq_lens,
+             max_seqlen,
+             is_causal=False,
+         )
+
+         if self.use_qkv_parallel:
+
+             # [b * s, head, head_dim] --> [b, s, head * head_dim]
+             output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)
+
+             # [b, s, head, head_dim] --> [b, s, head, head_dim]
+             output, _ = self.proj(output)
+         else:
+             # [b * s, head, head_dim] --> [b, s, head, head_dim]
+             context_layer = rearrange(output, "(b s) ... -> b s ...", b=bsz)
+
+             # [s, b, num_heads * head_size]
+             context_layer = rearrange(
+                 context_layer, "b s h d -> s b (h d)"
+             ).contiguous()
+
+             # [s, b, num_heads * head_size] --> [s, b, num_heads * head_size]
+             output, _ = self.proj(context_layer)
+
+         output = output.view(bsz, s, -1)
+
+         return output
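For orientation, a minimal standalone sketch of the rotary math used by rotate_half / apply_rotary_emb_torch above, restricted to the non-interleaved case with the full head dimension rotated; it uses plain PyTorch only (no sglang or einops imports), so the shapes and names are illustrative rather than the module's API:

    # Illustrative only: mirrors the non-interleaved half-rotation above.
    import torch

    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        # Split the last dim into two halves and rotate: (x1, x2) -> (-x2, x1).
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        # x: [seq, heads, head_dim]; cos/sin: [seq, head_dim // 2], broadcast over heads.
        cos = torch.cat([cos, cos], dim=-1)[:, None, :]  # [seq, 1, head_dim]
        sin = torch.cat([sin, sin], dim=-1)[:, None, :]
        return x * cos + rotate_half(x) * sin

    seq, heads, head_dim = 4, 2, 8
    freqs = torch.outer(torch.arange(seq, dtype=torch.float32),
                        torch.arange(head_dim // 2, dtype=torch.float32))
    x = torch.randn(seq, heads, head_dim)
    print(apply_rotary(x, freqs.cos(), freqs.sin()).shape)  # torch.Size([4, 2, 8])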
sglang/srt/layers/dp_attention.py
@@ -0,0 +1,69 @@
+ import torch
+
+ from sglang.srt.distributed import GroupCoordinator, get_tp_group
+
+ _ATTN_TP_GROUP = None
+ _ATTN_TP_RANK = None
+ _ATTN_TP_SIZE = None
+ _DP_RANK = None
+ _DP_SIZE = None
+
+
+ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
+     if not enable_dp_attention:
+         return tp_rank, tp_size, 0
+
+     attn_tp_size = tp_size // dp_size
+     dp_rank = tp_rank // attn_tp_size
+     attn_tp_rank = tp_rank % attn_tp_size
+     return attn_tp_rank, attn_tp_size, dp_rank
+
+
+ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
+     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
+
+     _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
+         enable_dp_attention, tp_rank, tp_size, dp_size
+     )
+     _DP_SIZE = dp_size
+
+     tp_group = get_tp_group()
+     _ATTN_TP_GROUP = GroupCoordinator(
+         [
+             list(range(head, head + _ATTN_TP_SIZE))
+             for head in range(0, tp_size, _ATTN_TP_SIZE)
+         ],
+         tp_rank,
+         torch.distributed.get_backend(tp_group.device_group),
+         False,
+         False,
+         False,
+         False,
+         False,
+         group_name="attention_tp",
+     )
+
+
+ def get_attention_tp_group():
+     assert _ATTN_TP_GROUP is not None, "dp attention not initialized!"
+     return _ATTN_TP_GROUP
+
+
+ def get_attention_tp_rank():
+     assert _ATTN_TP_RANK is not None, "dp attention not initialized!"
+     return _ATTN_TP_RANK
+
+
+ def get_attention_tp_size():
+     assert _ATTN_TP_SIZE is not None, "dp attention not initialized!"
+     return _ATTN_TP_SIZE
+
+
+ def get_attention_dp_rank():
+     assert _DP_RANK is not None, "dp attention not initialized!"
+     return _DP_RANK
+
+
+ def get_attention_dp_size():
+     assert _DP_SIZE is not None, "dp attention not initialized!"
+     return _DP_SIZE
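The rank bookkeeping above is plain integer arithmetic, so it can be sanity-checked without any process groups. A standalone sketch (the function body is copied from the diff; the tp_size=8 / dp_size=2 layout is an assumed example):

    def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
        if not enable_dp_attention:
            return tp_rank, tp_size, 0

        attn_tp_size = tp_size // dp_size      # ranks per attention-TP group
        dp_rank = tp_rank // attn_tp_size      # which DP replica this rank serves
        attn_tp_rank = tp_rank % attn_tp_size  # rank within that replica
        return attn_tp_rank, attn_tp_size, dp_rank

    # With tp_size=8 and dp_size=2, ranks 0-3 form DP replica 0 and ranks 4-7 form replica 1.
    for tp_rank in range(8):
        print(tp_rank, compute_dp_attention_world_info(True, tp_rank, 8, 2))
    # e.g. tp_rank=5 -> (attn_tp_rank=1, attn_tp_size=4, dp_rank=1)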
sglang/srt/layers/linear.py
@@ -7,7 +7,8 @@ from typing import Dict, List, Optional, Tuple
  import torch
  import torch.nn.functional as F
  from torch.nn.parameter import Parameter, UninitializedParameter
- from vllm.distributed import (
+
+ from sglang.srt.distributed import (
      divide,
      get_tensor_model_parallel_rank,
      get_tensor_model_parallel_world_size,
@@ -15,10 +16,6 @@ from vllm.distributed import (
      tensor_model_parallel_all_gather,
      tensor_model_parallel_all_reduce,
  )
-
- # Workaround: many QuantizationConfig still depends on this, so we have to use vLLM's LinearBase now.
- from vllm.model_executor.layers.linear import LinearBase
-
  from sglang.srt.layers.parameter import (
      BasevLLMParameter,
      PackedColumnParameter,
@@ -174,6 +171,45 @@ class UnquantizedLinearMethod(LinearMethodBase):
          return F.linear(x, layer.weight, bias)
 
 
+ class LinearBase(torch.nn.Module):
+     """Base linear layer.
+
+     Args:
+         input_size: input dimension of the linear layer.
+         output_size: output dimension of the linear layer.
+         bias: If true, add bias.
+         skip_bias_add: If true, skip adding bias but instead return it.
+         params_dtype: Data type for the parameters.
+         quant_config: Quantization configure.
+     """
+
+     def __init__(
+         self,
+         input_size: int,
+         output_size: int,
+         skip_bias_add: bool = False,
+         params_dtype: Optional[torch.dtype] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+
+         # Keep input parameters
+         self.input_size = input_size
+         self.output_size = output_size
+         self.skip_bias_add = skip_bias_add
+         if params_dtype is None:
+             params_dtype = torch.get_default_dtype()
+         self.params_dtype = params_dtype
+         if quant_config is None:
+             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedLinearMethod()
+         else:
+             self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         raise NotImplementedError
+
+
  class ReplicatedLinear(LinearBase):
      """Replicated linear layer.
 
sglang/srt/layers/logits_processor.py
@@ -14,17 +14,18 @@
  """Logits processing."""
 
  import dataclasses
+ import logging
  from typing import List, Optional, Union
 
  import torch
  import triton
  import triton.language as tl
  from torch import nn
- from vllm.distributed import (
+
+ from sglang.srt.distributed import (
      get_tensor_model_parallel_world_size,
      tensor_model_parallel_all_gather,
  )
-
  from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
  from sglang.srt.model_executor.forward_batch_info import (
      CaptureHiddenMode,
@@ -32,6 +33,8 @@ from sglang.srt.model_executor.forward_batch_info import (
      ForwardMode,
  )
 
+ logger = logging.getLogger(__name__)
+
 
  @dataclasses.dataclass
  class LogitsProcessorOutput:
@@ -50,8 +53,6 @@ class LogitsProcessorOutput:
      next_token_top_logprobs_idx: Optional[List] = None
 
      ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
-     # The normlaized logprobs of prompts. shape: [#seq]
-     normalized_prompt_logprobs: torch.Tensor = None
      # The logprobs of input tokens. shape: [#token]
      input_token_logprobs: torch.Tensor = None
      # The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
@@ -129,59 +130,70 @@ class LogitsProcessor(nn.Module):
          hidden_states,
          lm_head: VocabParallelEmbedding,
          logits_metadata: Union[LogitsMetadata, ForwardBatch],
-     ):
+     ) -> LogitsProcessorOutput:
          if isinstance(logits_metadata, ForwardBatch):
              logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
 
          # Get the last hidden states and last logits for the next token prediction
          if (
-             logits_metadata.forward_mode.is_decode()
+             logits_metadata.forward_mode.is_decode_or_idle()
              or logits_metadata.forward_mode.is_target_verify()
          ):
-             last_index = None
-             last_hidden = hidden_states
-         else:
+             pruned_states = hidden_states
+             sample_indices = None
+         elif (
+             logits_metadata.forward_mode.is_extend()
+             and not logits_metadata.extend_return_logprob
+         ):
+             # Prefill without input logprobs.
              last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
-             last_hidden = hidden_states[last_index]
+             pruned_states = hidden_states[last_index]
+             sample_indices = None
+         else:
+             # Slice the requested tokens to compute logprob
+             sample_index_pt = -1
+             sample_indices = []
+             pt, pruned_states, pruned_input_ids = 0, [], []
+             for start_len, extend_len in zip(
+                 logits_metadata.extend_logprob_start_lens_cpu,
+                 logits_metadata.extend_seq_lens_cpu,
+             ):
+                 pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
+                 sample_index_pt += extend_len - start_len
+                 sample_indices.append(sample_index_pt)
+                 pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
+                 pt += extend_len
+
+             pruned_states = torch.cat(pruned_states)
+
+         # Compute logits for both input and sampled tokens.
+         logits = self._get_logits(pruned_states, lm_head, logits_metadata)
+         sampled_logits = (
+             logits[sample_indices] if sample_indices is not None else logits
+         )
 
-         # Compute logits
-         last_logits = self._get_logits(last_hidden, lm_head)
          if (
              not logits_metadata.extend_return_logprob
              or logits_metadata.capture_hidden_mode.need_capture()
          ):
              # Decode mode or extend mode without return_logprob.
              return LogitsProcessorOutput(
-                 next_token_logits=last_logits,
+                 next_token_logits=sampled_logits,
                  hidden_states=(
                      hidden_states
                      if logits_metadata.capture_hidden_mode.is_full()
                      else (
-                         last_hidden
+                         pruned_states
                          if logits_metadata.capture_hidden_mode.is_last()
                          else None
                      )
                  ),
              )
          else:
-             # Slice the requested tokens to compute logprob
-             pt, pruned_states, pruned_input_ids = 0, [], []
-             for start_len, extend_len in zip(
-                 logits_metadata.extend_logprob_start_lens_cpu,
-                 logits_metadata.extend_seq_lens_cpu,
-             ):
-                 pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
-                 pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
-                 pt += extend_len
-
-             # Compute the logits of all required tokens
-             pruned_states = torch.cat(pruned_states)
-             del hidden_states
-             input_token_logits = self._get_logits(pruned_states, lm_head)
-             del pruned_states
+             input_logprobs = logits
+             del hidden_states, logits
 
              # Normalize the logprob w/o temperature, top-p
-             input_logprobs = input_token_logits
             input_logprobs = self.compute_temp_top_p_normalized_logprobs(
                  input_logprobs, logits_metadata
              )
@@ -195,25 +207,18 @@ class LogitsProcessor(nn.Module):
              else:
                  input_top_logprobs_val = input_top_logprobs_idx = None
 
-             # Compute the normalized logprobs for the requested tokens.
-             # Note that we pad a zero at the end for easy batching.
              input_token_logprobs = input_logprobs[
-                 torch.arange(input_logprobs.shape[0], device="cuda"),
+                 torch.arange(input_logprobs.shape[0], device=input_logprobs.device),
                  torch.cat(
                      [
                          torch.cat(pruned_input_ids)[1:],
-                         torch.tensor([0], device="cuda"),
+                         torch.tensor([0], device=input_logprobs.device),
                      ]
                  ),
              ]
-             normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
-                 input_token_logprobs,
-                 logits_metadata,
-             )
 
              return LogitsProcessorOutput(
-                 next_token_logits=last_logits,
-                 normalized_prompt_logprobs=normalized_prompt_logprobs,
+                 next_token_logits=sampled_logits,
                  input_token_logprobs=input_token_logprobs,
                  input_top_logprobs_val=input_top_logprobs_val,
                  input_top_logprobs_idx=input_top_logprobs_idx,
@@ -223,8 +228,11 @@ class LogitsProcessor(nn.Module):
          self,
          hidden_states: torch.Tensor,
          lm_head: VocabParallelEmbedding,
+         logits_metadata: LogitsMetadata,
          embedding_bias: Optional[torch.Tensor] = None,
      ) -> torch.Tensor:
+         """Get logits from hidden_states."""
+
          if hasattr(lm_head, "weight"):
              logits = torch.matmul(hidden_states, lm_head.weight.T)
          else:
@@ -237,8 +245,6 @@ class LogitsProcessor(nn.Module):
          if self.do_tensor_parallel_all_gather:
              logits = tensor_model_parallel_all_gather(logits)
 
-         # Compute the normalized logprobs for the requested tokens.
-         # Note that we pad a zero at the end for easy batching.
          logits = logits[:, : self.config.vocab_size].float()
 
          if self.final_logit_softcapping:
@@ -246,27 +252,6 @@ class LogitsProcessor(nn.Module):
 
          return logits
 
-     @staticmethod
-     def _get_normalized_prompt_logprobs(
-         input_token_logprobs: torch.Tensor,
-         logits_metadata: LogitsMetadata,
-     ):
-         logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
-         pruned_lens = torch.tensor(
-             logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
-         )
-
-         start = torch.zeros_like(pruned_lens)
-         start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
-         end = torch.clamp(
-             start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
-         )
-         sum_logp = (
-             logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
-         )
-         normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
-         return normalized_prompt_logprobs
-
      @staticmethod
      def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
          max_k = max(logits_metadata.top_logprobs_nums)
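A small numeric walk-through of the new slicing logic in LogitsProcessor.forward (plain tensors only; the two request lengths and logprob start offsets below are made up for illustration):

    import torch

    # Two packed requests with extend lengths 5 and 3; logprobs requested
    # starting at offsets 2 and 0 respectively.
    hidden_states = torch.arange(8).float().unsqueeze(-1)  # 5 + 3 packed tokens
    extend_seq_lens_cpu = [5, 3]
    extend_logprob_start_lens_cpu = [2, 0]

    pt, pruned_states, sample_indices, sample_index_pt = 0, [], [], -1
    for start_len, extend_len in zip(extend_logprob_start_lens_cpu, extend_seq_lens_cpu):
        pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
        sample_index_pt += extend_len - start_len
        sample_indices.append(sample_index_pt)  # last kept token of each request
        pt += extend_len

    pruned_states = torch.cat(pruned_states)
    print(pruned_states.squeeze(-1).tolist())  # [2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
    print(sample_indices)                      # [2, 5] -> rows used for next-token logits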
sglang/srt/layers/moe/ep_moe/layer.py
@@ -4,13 +4,12 @@ from typing import Callable, List, Optional, Tuple
  import torch
  from torch.nn import Module
  from vllm import _custom_ops as ops
- from vllm.distributed import (
+ from vllm.model_executor.custom_op import CustomOp
+
+ from sglang.srt.distributed import (
      get_tensor_model_parallel_rank,
      get_tensor_model_parallel_world_size,
  )
- from vllm.model_executor.custom_op import CustomOp
- from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
-
  from sglang.srt.layers.custom_op_util import register_custom_op
  from sglang.srt.layers.moe.ep_moe.kernels import (
      grouped_gemm_triton,
@@ -25,6 +24,7 @@ from sglang.srt.layers.quantization.base_config import (
      QuantizationConfig,
      QuantizeMethodBase,
  )
+ from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
  from sglang.srt.utils import is_hip, set_weight_attrs
 
  logger = logging.getLogger(__name__)
sglang/srt/layers/moe/fused_moe_native.py
@@ -8,6 +8,7 @@ from typing import Callable, Optional
  import torch
  from torch.nn import functional as F
 
+ from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.moe.topk import select_experts
 
 
@@ -44,3 +45,71 @@ def fused_moe_forward_native(
      x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
      expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
      return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
+
+
+ def moe_forward_native(
+     layer: torch.nn.Module,
+     x: torch.Tensor,
+     use_grouped_topk: bool,
+     top_k: int,
+     router_logits: torch.Tensor,
+     renormalize: bool,
+     topk_group: Optional[int] = None,
+     num_expert_group: Optional[int] = None,
+     custom_routing_function: Optional[Callable] = None,
+     correction_bias: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+
+     topk_weights, topk_ids = select_experts(
+         hidden_states=x,
+         router_logits=router_logits,
+         use_grouped_topk=use_grouped_topk,
+         top_k=top_k,
+         renormalize=renormalize,
+         topk_group=topk_group,
+         num_expert_group=num_expert_group,
+         custom_routing_function=custom_routing_function,
+         correction_bias=correction_bias,
+         torch_native=True,
+     )
+
+     # Ref code from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/e0828e3cc0a03408724b80c3cc92c8e072db8d01/modeling_deepseek.py#L589
+     len_experts = layer.num_experts
+
+     cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts))
+     cnts.scatter_(1, topk_ids.to(torch.int64), 1)
+     tokens_per_expert = cnts.sum(dim=0)
+     idxs = topk_ids.view(-1).argsort()
+
+     sorted_tokens = x[idxs // topk_ids.shape[1]]
+     tokens_per_expert = tokens_per_expert.cpu().numpy()
+
+     outputs = []
+     start_idx = 0
+     for i, num_tokens in enumerate(tokens_per_expert):
+         end_idx = start_idx + num_tokens
+         if num_tokens == 0:
+             continue
+         tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
+
+         layer_w13_weight = layer.w13_weight[i]
+         layer_w2_weight = layer.w2_weight[i]
+
+         gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
+         gate_up = SiluAndMul()(gate_up)
+         expert_out = F.linear(gate_up, layer_w2_weight)
+         outputs.append(expert_out)
+         start_idx = end_idx
+
+     outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
+     new_x = torch.empty_like(outs)
+
+     new_x[idxs] = outs
+     final_out = (
+         new_x.view(*topk_ids.shape, -1)
+         .type(topk_weights.dtype)
+         .mul_(topk_weights.unsqueeze(dim=-1))
+         .sum(dim=1)
+         .type(new_x.dtype)
+     )
+     return final_out
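A toy-sized illustration of the grouping trick moe_forward_native relies on: scatter per-expert counts, argsort the flattened top-k expert ids, and gather tokens so each expert sees a contiguous slice. Plain PyTorch; the routing table below is invented for the example:

    import torch

    num_experts, top_k = 4, 2
    topk_ids = torch.tensor([[0, 2], [1, 2], [0, 3]])  # 3 tokens, 2 experts each
    x = torch.arange(3).float().unsqueeze(-1)          # token features

    cnts = topk_ids.new_zeros((topk_ids.shape[0], num_experts))
    cnts.scatter_(1, topk_ids, 1)
    tokens_per_expert = cnts.sum(dim=0)

    idxs = topk_ids.view(-1).argsort()      # flat slots reordered expert-by-expert
    sorted_tokens = x[idxs // top_k]        # gathered inputs, grouped per expert

    print(tokens_per_expert.tolist())        # [2, 1, 2, 1]
    print(topk_ids.view(-1)[idxs].tolist())  # [0, 0, 1, 2, 2, 3] -> contiguous per expert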
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -15,15 +15,18 @@ from vllm import _custom_ops as ops
 
  from sglang.srt.layers.moe.topk import select_experts
  from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
- from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
+ from sglang.srt.utils import (
+     direct_register_custom_op,
+     get_device_name,
+     is_cuda_available,
+     is_hip,
+ )
 
- is_hip_flag = False
- if not is_hip():
+ is_cuda = is_cuda_available()
+ is_hip_flag = is_hip()
+ if is_cuda:
      from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 
-     is_hip_flag = False
- else:
-     is_hip_flag = True
 
  logger = logging.getLogger(__name__)
  padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0