sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +24 -16
  4. sglang/bench_one_batch.py +51 -3
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +37 -28
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +15 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/model_config.py +16 -6
  13. sglang/srt/constrained/base_grammar_backend.py +21 -0
  14. sglang/srt/constrained/xgrammar_backend.py +8 -4
  15. sglang/srt/conversation.py +14 -1
  16. sglang/srt/distributed/__init__.py +3 -3
  17. sglang/srt/distributed/communication_op.py +2 -1
  18. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  21. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  22. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  23. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  24. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  25. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  26. sglang/srt/distributed/parallel_state.py +1 -1
  27. sglang/srt/distributed/utils.py +2 -1
  28. sglang/srt/entrypoints/engine.py +449 -0
  29. sglang/srt/entrypoints/http_server.py +579 -0
  30. sglang/srt/layers/activation.py +3 -3
  31. sglang/srt/layers/attention/flashinfer_backend.py +27 -12
  32. sglang/srt/layers/attention/triton_backend.py +4 -6
  33. sglang/srt/layers/attention/vision.py +204 -0
  34. sglang/srt/layers/dp_attention.py +69 -0
  35. sglang/srt/layers/linear.py +76 -102
  36. sglang/srt/layers/logits_processor.py +48 -63
  37. sglang/srt/layers/moe/ep_moe/layer.py +4 -4
  38. sglang/srt/layers/moe/fused_moe_native.py +69 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -14
  41. sglang/srt/layers/moe/topk.py +4 -2
  42. sglang/srt/layers/parameter.py +26 -17
  43. sglang/srt/layers/quantization/__init__.py +22 -23
  44. sglang/srt/layers/quantization/fp8.py +112 -55
  45. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  46. sglang/srt/layers/quantization/int8_kernel.py +54 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +2 -3
  48. sglang/srt/layers/quantization/w8a8_int8.py +117 -0
  49. sglang/srt/layers/radix_attention.py +2 -0
  50. sglang/srt/layers/rotary_embedding.py +1179 -31
  51. sglang/srt/layers/sampler.py +39 -1
  52. sglang/srt/layers/vocab_parallel_embedding.py +17 -4
  53. sglang/srt/lora/lora.py +1 -9
  54. sglang/srt/managers/configure_logging.py +46 -0
  55. sglang/srt/managers/data_parallel_controller.py +79 -72
  56. sglang/srt/managers/detokenizer_manager.py +23 -8
  57. sglang/srt/managers/image_processor.py +158 -2
  58. sglang/srt/managers/io_struct.py +54 -15
  59. sglang/srt/managers/schedule_batch.py +49 -22
  60. sglang/srt/managers/schedule_policy.py +26 -12
  61. sglang/srt/managers/scheduler.py +319 -181
  62. sglang/srt/managers/session_controller.py +1 -0
  63. sglang/srt/managers/tokenizer_manager.py +303 -158
  64. sglang/srt/managers/tp_worker.py +6 -4
  65. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  66. sglang/srt/managers/utils.py +44 -0
  67. sglang/srt/mem_cache/memory_pool.py +110 -77
  68. sglang/srt/metrics/collector.py +25 -11
  69. sglang/srt/model_executor/cuda_graph_runner.py +4 -6
  70. sglang/srt/model_executor/model_runner.py +80 -21
  71. sglang/srt/model_loader/loader.py +8 -6
  72. sglang/srt/model_loader/weight_utils.py +55 -2
  73. sglang/srt/models/baichuan.py +6 -6
  74. sglang/srt/models/chatglm.py +2 -2
  75. sglang/srt/models/commandr.py +3 -3
  76. sglang/srt/models/dbrx.py +4 -4
  77. sglang/srt/models/deepseek.py +3 -3
  78. sglang/srt/models/deepseek_v2.py +8 -8
  79. sglang/srt/models/exaone.py +2 -2
  80. sglang/srt/models/gemma.py +2 -2
  81. sglang/srt/models/gemma2.py +6 -24
  82. sglang/srt/models/gpt2.py +3 -5
  83. sglang/srt/models/gpt_bigcode.py +1 -1
  84. sglang/srt/models/granite.py +2 -2
  85. sglang/srt/models/grok.py +3 -3
  86. sglang/srt/models/internlm2.py +2 -2
  87. sglang/srt/models/llama.py +41 -4
  88. sglang/srt/models/minicpm.py +2 -2
  89. sglang/srt/models/minicpm3.py +6 -6
  90. sglang/srt/models/minicpmv.py +1238 -0
  91. sglang/srt/models/mixtral.py +3 -3
  92. sglang/srt/models/mixtral_quant.py +3 -3
  93. sglang/srt/models/mllama.py +2 -2
  94. sglang/srt/models/olmo.py +3 -3
  95. sglang/srt/models/olmo2.py +4 -4
  96. sglang/srt/models/olmoe.py +7 -13
  97. sglang/srt/models/phi3_small.py +2 -2
  98. sglang/srt/models/qwen.py +2 -2
  99. sglang/srt/models/qwen2.py +52 -4
  100. sglang/srt/models/qwen2_eagle.py +131 -0
  101. sglang/srt/models/qwen2_moe.py +3 -3
  102. sglang/srt/models/qwen2_vl.py +22 -122
  103. sglang/srt/models/stablelm.py +2 -2
  104. sglang/srt/models/torch_native_llama.py +3 -3
  105. sglang/srt/models/xverse.py +6 -6
  106. sglang/srt/models/xverse_moe.py +6 -6
  107. sglang/srt/openai_api/protocol.py +2 -0
  108. sglang/srt/sampling/custom_logit_processor.py +38 -0
  109. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
  110. sglang/srt/sampling/sampling_batch_info.py +153 -9
  111. sglang/srt/sampling/sampling_params.py +4 -2
  112. sglang/srt/server.py +4 -1037
  113. sglang/srt/server_args.py +84 -32
  114. sglang/srt/speculative/eagle_worker.py +1 -0
  115. sglang/srt/torch_memory_saver_adapter.py +59 -0
  116. sglang/srt/utils.py +130 -63
  117. sglang/test/runners.py +8 -13
  118. sglang/test/test_programs.py +1 -1
  119. sglang/test/test_utils.py +3 -1
  120. sglang/utils.py +12 -2
  121. sglang/version.py +1 -1
  122. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +26 -13
  123. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +126 -117
  124. sglang/launch_server_llavavid.py +0 -25
  125. sglang/srt/constrained/__init__.py +0 -16
  126. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  127. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
  128. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
  129. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/srt/layers/linear.py +76 -102

@@ -1,4 +1,4 @@
-# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/linear.py
+"""Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py"""

 import logging
 from abc import abstractmethod
@@ -7,7 +7,8 @@ from typing import Dict, List, Optional, Tuple
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter, UninitializedParameter
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -15,17 +16,12 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
-
-# workaround
-from vllm.model_executor.layers.linear import LinearBase
-
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     PackedColumnParameter,
     PackedvLLMParameter,
     PerTensorScaleParameter,
     RowvLLMParameter,
-    _ColumnvLLMParameter,
 )
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -43,9 +39,13 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "GPTQMarlinLinearMethod",
     "Fp8LinearMethod",
     "MarlinLinearMethod",
-    "GPTQLinearMethod",
     "QQQLinearMethod",
+    "GPTQMarlin24LinearMethod",
+    "TPUInt8LinearMethod",
+    "GPTQLinearMethod",
+    "FBGEMMFp8LinearMethod",
     "ModelOptFp8LinearMethod",
+    "IPEXAWQLinearMethod",
 ]


@@ -95,62 +95,6 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
     return param[shard_id], loaded_weight


-def load_column_qkv_weight(
-    self, loaded_weight, num_heads, shard_id, shard_offset, shard_size, tp_rank
-):
-    if (
-        isinstance(self, (PackedColumnParameter, PackedvLLMParameter))
-        and self.output_dim == self.packed_dim
-    ):
-        shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
-            shard_offset=shard_offset, shard_size=shard_size
-        )
-
-    param_data = self.data
-    shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
-    param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
-    loaded_weight = loaded_weight.narrow(
-        self.output_dim, shard_id * shard_size, shard_size
-    )
-
-    assert param_data.shape == loaded_weight.shape
-    param_data.copy_(loaded_weight)
-
-
-def load_column_parallel_weight(
-    self, loaded_weight: torch.Tensor, tp_rank, use_presharded_weights: bool = False
-):
-    if isinstance(self, _ColumnvLLMParameter):
-        if not use_presharded_weights:
-            shard_size = self.data.shape[self.output_dim]
-            loaded_weight = loaded_weight.narrow(
-                self.output_dim, tp_rank * shard_size, shard_size
-            )
-        assert self.data.shape == loaded_weight.shape
-        self.data.copy_(loaded_weight)
-    else:
-        self.data.copy_(loaded_weight)
-
-
-def load_row_parallel_weight(
-    self, loaded_weight: torch.Tensor, tp_rank, use_presharded_weights: bool = False
-):
-    if isinstance(self, RowvLLMParameter):
-        if not use_presharded_weights:
-            shard_size = self.data.shape[self.input_dim]
-            loaded_weight = loaded_weight.narrow(
-                self.input_dim, tp_rank * shard_size, shard_size
-            )
-
-        if len(loaded_weight.shape) == 0:
-            loaded_weight = loaded_weight.reshape(1)
-
-        assert self.data.shape == loaded_weight.shape
-        self.data.copy_(loaded_weight)
-    else:
-        self.data.copy_(loaded_weight)
-
-
 class LinearMethodBase(QuantizeMethodBase):
     """Base class for different (maybe quantized) linear methods."""

@@ -227,6 +171,45 @@ class UnquantizedLinearMethod(LinearMethodBase):
         return F.linear(x, layer.weight, bias)


+class LinearBase(torch.nn.Module):
+    """Base linear layer.
+
+    Args:
+        input_size: input dimension of the linear layer.
+        output_size: output dimension of the linear layer.
+        bias: If true, add bias.
+        skip_bias_add: If true, skip adding bias but instead return it.
+        params_dtype: Data type for the parameters.
+        quant_config: Quantization configure.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        # Keep input parameters
+        self.input_size = input_size
+        self.output_size = output_size
+        self.skip_bias_add = skip_bias_add
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
+        if quant_config is None:
+            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedLinearMethod()
+        else:
+            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError
+
+
 class ReplicatedLinear(LinearBase):
     """Replicated linear layer.

@@ -426,9 +409,7 @@ class ColumnParallelLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)

-        assert (
-            param_data.shape == loaded_weight.shape
-        ), f"{param_data.shape=}, {loaded_weight.shape=}"
+        assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)

     def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
@@ -437,7 +418,7 @@ class ColumnParallelLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)
-        param.load_column_parallel_weight(loaded_weight=loaded_weight)
+        param.load_column_parallel_weight(loaded_weight, tp_rank=self.tp_rank)

     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None
@@ -565,9 +546,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 param_data, loaded_weight, 0
             )

-            assert (
-                param_data.shape == loaded_weight.shape
-            ), f"{param_data.shape=}, {loaded_weight.shape=}"
+            assert param_data.shape == loaded_weight.shape
            param_data.copy_(loaded_weight)
            return
        current_shard_offset = 0
@@ -643,9 +622,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                "the same for all partitions."
            )

-        assert (
-            param_data.shape == loaded_weight.shape
-        ), f"{param_data.shape=}, {loaded_weight.shape=}"
+        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def _load_fused_module_from_checkpoint(
@@ -697,6 +674,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
+            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

@@ -882,6 +860,7 @@ class QKVParallelLinear(ColumnParallelLinear):
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight)
                return
+            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

@@ -896,24 +875,14 @@ class QKVParallelLinear(ColumnParallelLinear):
            shard_offset = (shard_offset + block_n - 1) // block_n
            shard_size = (shard_size + block_n - 1) // block_n

-        if isinstance(param, _ColumnvLLMParameter):
-            load_column_qkv_weight(
-                param,
-                loaded_weight,
-                num_heads=self.num_kv_head_replicas,
-                shard_id=loaded_shard_id,
-                shard_offset=shard_offset,
-                shard_size=shard_size,
-                tp_rank=self.tp_rank,
-            )
-        else:
-            param.load_qkv_weight(
-                loaded_weight=loaded_weight,
-                num_heads=self.num_kv_head_replicas,
-                shard_id=loaded_shard_id,
-                shard_offset=shard_offset,
-                shard_size=shard_size,
-            )
+        param.load_qkv_weight(
+            loaded_weight=loaded_weight,
+            num_heads=self.num_kv_head_replicas,
+            shard_id=loaded_shard_id,
+            shard_offset=shard_offset,
+            shard_size=shard_size,
+            tp_rank=self.tp_rank,
+        )

    def weight_loader(
        self,
@@ -962,9 +931,7 @@ class QKVParallelLinear(ColumnParallelLinear):
                param_data, loaded_weight, 0
            )

-            assert (
-                param_data.shape == loaded_weight.shape
-            ), f"{param_data.shape=}, {loaded_weight.shape=}"
+            assert param_data.shape == loaded_weight.shape
            param_data.copy_(loaded_weight)
            return
        shard_offsets = [
@@ -1105,9 +1072,7 @@ class QKVParallelLinear(ColumnParallelLinear):
                "for all partitions."
            )

-        assert (
-            param_data.shape == loaded_weight.shape
-        ), f"{param_data.shape=}, {loaded_weight.shape=}"
+        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


@@ -1234,9 +1199,7 @@ class RowParallelLinear(LinearBase):
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

-        assert (
-            param_data.shape == loaded_weight.shape
-        ), f"{param_data.shape=}, {loaded_weight.shape=}"
+        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
@@ -1247,7 +1210,18 @@ class RowParallelLinear(LinearBase):
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

-        param.load_row_parallel_weight(loaded_weight=loaded_weight)
+        if isinstance(param, BasevLLMParameter):
+            # This `BasevLLMParameter` is defined in sglang/srt/layers/parameter.py,
+            # It supports additional parameters like tp_rank and use_presharded_weights.
+            param.load_row_parallel_weight(
+                loaded_weight,
+                tp_rank=self.tp_rank,
+                use_presharded_weights=self.use_presharded_weights,
+            )
+        else:
+            # `params` is defined in `vllm/model_executor/parameter.py`,
+            # It does not support additional parameters.
+            param.load_row_parallel_weight(loaded_weight)

    def forward(self, input_):
        if self.input_is_parallel:
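
The monkeypatch-style helpers removed above (load_column_qkv_weight, load_column_parallel_weight, load_row_parallel_weight) are replaced by rank-aware loaders on the parameter classes in sglang/srt/layers/parameter.py, with the linear layers now passing tp_rank explicitly. As a rough illustration of the sharding behavior involved, here is a minimal self-contained sketch; the load_column_shard helper and tensor names are made up and are not sglang's API:

    # Illustrative only: a generic rank-aware column-shard copy, assuming the same
    # narrow-by-rank behavior shown in the removed helpers above.
    import torch

    def load_column_shard(param, full_weight, tp_rank, output_dim=0,
                          use_presharded_weights=False):
        # Select this rank's slice along the output dimension, unless the
        # checkpoint is already sharded per rank.
        if not use_presharded_weights:
            shard_size = param.shape[output_dim]
            full_weight = full_weight.narrow(output_dim, tp_rank * shard_size, shard_size)
        assert param.shape == full_weight.shape
        param.copy_(full_weight)

    # An [8, 4] weight split across 2 ranks along dim 0 -> each rank holds [4, 4].
    full = torch.arange(32, dtype=torch.float32).reshape(8, 4)
    shard = torch.empty(4, 4)
    load_column_shard(shard, full, tp_rank=1)
    assert torch.equal(shard, full[4:8])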
sglang/srt/layers/logits_processor.py +48 -63

@@ -14,17 +14,18 @@
 """Logits processing."""

 import dataclasses
+import logging
 from typing import List, Optional, Union

 import torch
 import triton
 import triton.language as tl
 from torch import nn
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
 )
-
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
@@ -32,6 +33,8 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )

+logger = logging.getLogger(__name__)
+

 @dataclasses.dataclass
 class LogitsProcessorOutput:
@@ -50,8 +53,6 @@ class LogitsProcessorOutput:
     next_token_top_logprobs_idx: Optional[List] = None

     ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
-    # The normlaized logprobs of prompts. shape: [#seq]
-    normalized_prompt_logprobs: torch.Tensor = None
     # The logprobs of input tokens. shape: [#token]
     input_token_logprobs: torch.Tensor = None
     # The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
@@ -129,59 +130,70 @@ class LogitsProcessor(nn.Module):
         hidden_states,
         lm_head: VocabParallelEmbedding,
         logits_metadata: Union[LogitsMetadata, ForwardBatch],
-    ):
+    ) -> LogitsProcessorOutput:
         if isinstance(logits_metadata, ForwardBatch):
             logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)

         # Get the last hidden states and last logits for the next token prediction
         if (
-            logits_metadata.forward_mode.is_decode()
+            logits_metadata.forward_mode.is_decode_or_idle()
             or logits_metadata.forward_mode.is_target_verify()
         ):
-            last_index = None
-            last_hidden = hidden_states
-        else:
+            pruned_states = hidden_states
+            sample_indices = None
+        elif (
+            logits_metadata.forward_mode.is_extend()
+            and not logits_metadata.extend_return_logprob
+        ):
+            # Prefill without input logprobs.
             last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
-            last_hidden = hidden_states[last_index]
+            pruned_states = hidden_states[last_index]
+            sample_indices = None
+        else:
+            # Slice the requested tokens to compute logprob
+            sample_index_pt = -1
+            sample_indices = []
+            pt, pruned_states, pruned_input_ids = 0, [], []
+            for start_len, extend_len in zip(
+                logits_metadata.extend_logprob_start_lens_cpu,
+                logits_metadata.extend_seq_lens_cpu,
+            ):
+                pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
+                sample_index_pt += extend_len - start_len
+                sample_indices.append(sample_index_pt)
+                pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
+                pt += extend_len
+
+            pruned_states = torch.cat(pruned_states)
+
+        # Compute logits for both input and sampled tokens.
+        logits = self._get_logits(pruned_states, lm_head, logits_metadata)
+        sampled_logits = (
+            logits[sample_indices] if sample_indices is not None else logits
+        )

-        # Compute logits
-        last_logits = self._get_logits(last_hidden, lm_head)
         if (
             not logits_metadata.extend_return_logprob
             or logits_metadata.capture_hidden_mode.need_capture()
         ):
             # Decode mode or extend mode without return_logprob.
             return LogitsProcessorOutput(
-                next_token_logits=last_logits,
+                next_token_logits=sampled_logits,
                 hidden_states=(
                     hidden_states
                     if logits_metadata.capture_hidden_mode.is_full()
                     else (
-                        last_hidden
+                        pruned_states
                         if logits_metadata.capture_hidden_mode.is_last()
                         else None
                     )
                 ),
             )
         else:
-            # Slice the requested tokens to compute logprob
-            pt, pruned_states, pruned_input_ids = 0, [], []
-            for start_len, extend_len in zip(
-                logits_metadata.extend_logprob_start_lens_cpu,
-                logits_metadata.extend_seq_lens_cpu,
-            ):
-                pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
-                pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
-                pt += extend_len
-
-            # Compute the logits of all required tokens
-            pruned_states = torch.cat(pruned_states)
-            del hidden_states
-            input_token_logits = self._get_logits(pruned_states, lm_head)
-            del pruned_states
+            input_logprobs = logits
+            del hidden_states, logits

             # Normalize the logprob w/o temperature, top-p
-            input_logprobs = input_token_logits
             input_logprobs = self.compute_temp_top_p_normalized_logprobs(
                 input_logprobs, logits_metadata
             )
@@ -195,25 +207,18 @@ class LogitsProcessor(nn.Module):
             else:
                 input_top_logprobs_val = input_top_logprobs_idx = None

-            # Compute the normalized logprobs for the requested tokens.
-            # Note that we pad a zero at the end for easy batching.
             input_token_logprobs = input_logprobs[
-                torch.arange(input_logprobs.shape[0], device="cuda"),
+                torch.arange(input_logprobs.shape[0], device=input_logprobs.device),
                 torch.cat(
                     [
                         torch.cat(pruned_input_ids)[1:],
-                        torch.tensor([0], device="cuda"),
+                        torch.tensor([0], device=input_logprobs.device),
                     ]
                 ),
             ]
-            normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
-                input_token_logprobs,
-                logits_metadata,
-            )

             return LogitsProcessorOutput(
-                next_token_logits=last_logits,
-                normalized_prompt_logprobs=normalized_prompt_logprobs,
+                next_token_logits=sampled_logits,
                 input_token_logprobs=input_token_logprobs,
                 input_top_logprobs_val=input_top_logprobs_val,
                 input_top_logprobs_idx=input_top_logprobs_idx,
@@ -223,8 +228,11 @@
         self,
         hidden_states: torch.Tensor,
         lm_head: VocabParallelEmbedding,
+        logits_metadata: LogitsMetadata,
         embedding_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        """Get logits from hidden_states."""
+
         if hasattr(lm_head, "weight"):
             logits = torch.matmul(hidden_states, lm_head.weight.T)
         else:
@@ -237,8 +245,6 @@
         if self.do_tensor_parallel_all_gather:
             logits = tensor_model_parallel_all_gather(logits)

-        # Compute the normalized logprobs for the requested tokens.
-        # Note that we pad a zero at the end for easy batching.
         logits = logits[:, : self.config.vocab_size].float()

         if self.final_logit_softcapping:
@@ -246,27 +252,6 @@
         return logits


-    @staticmethod
-    def _get_normalized_prompt_logprobs(
-        input_token_logprobs: torch.Tensor,
-        logits_metadata: LogitsMetadata,
-    ):
-        logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
-        pruned_lens = torch.tensor(
-            logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
-        )
-
-        start = torch.zeros_like(pruned_lens)
-        start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
-        end = torch.clamp(
-            start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
-        )
-        sum_logp = (
-            logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
-        )
-        normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
-        return normalized_prompt_logprobs
-
     @staticmethod
     def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
         max_k = max(logits_metadata.top_logprobs_nums)
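
The rewritten forward path above prunes the packed hidden states per request (skipping the first extend_logprob_start_lens tokens) and records sample_indices, the last kept row of each request, from which next_token_logits are taken. A toy sketch of just that indexing with made-up lengths, not the sglang code path:

    # Toy illustration (made-up lengths, hidden size 1).
    import torch

    extend_seq_lens = [3, 4]            # tokens added per request in this batch
    extend_logprob_start_lens = [1, 2]  # tokens per request to skip for logprobs
    hidden_states = torch.arange(7).float().unsqueeze(1)  # packed [7, 1]

    pt, pruned, sample_indices, sample_index_pt = 0, [], [], -1
    for start_len, extend_len in zip(extend_logprob_start_lens, extend_seq_lens):
        pruned.append(hidden_states[pt + start_len : pt + extend_len])
        sample_index_pt += extend_len - start_len   # last kept row of this request
        sample_indices.append(sample_index_pt)
        pt += extend_len

    pruned_states = torch.cat(pruned)               # rows 1, 2, 5, 6 of the packed input
    assert pruned_states.squeeze(1).tolist() == [1.0, 2.0, 5.0, 6.0]
    assert sample_indices == [1, 3]                 # sampled logits come from these rows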
sglang/srt/layers/moe/ep_moe/layer.py +4 -4

@@ -4,13 +4,12 @@ from typing import Callable, List, Optional, Tuple
 import torch
 from torch.nn import Module
 from vllm import _custom_ops as ops
-from vllm.distributed import (
+from vllm.model_executor.custom_op import CustomOp
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
-
 from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.moe.ep_moe.kernels import (
     grouped_gemm_triton,
@@ -25,6 +24,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.utils import is_hip, set_weight_attrs

 logger = logging.getLogger(__name__)
sglang/srt/layers/moe/fused_moe_native.py +69 -0

@@ -8,6 +8,7 @@ from typing import Callable, Optional
 import torch
 from torch.nn import functional as F

+from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.topk import select_experts


@@ -44,3 +45,71 @@ def fused_moe_forward_native(
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
+
+
+def moe_forward_native(
+    layer: torch.nn.Module,
+    x: torch.Tensor,
+    use_grouped_topk: bool,
+    top_k: int,
+    router_logits: torch.Tensor,
+    renormalize: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
+    correction_bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+
+    topk_weights, topk_ids = select_experts(
+        hidden_states=x,
+        router_logits=router_logits,
+        use_grouped_topk=use_grouped_topk,
+        top_k=top_k,
+        renormalize=renormalize,
+        topk_group=topk_group,
+        num_expert_group=num_expert_group,
+        custom_routing_function=custom_routing_function,
+        correction_bias=correction_bias,
+        torch_native=True,
+    )
+
+    # Ref code from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/e0828e3cc0a03408724b80c3cc92c8e072db8d01/modeling_deepseek.py#L589
+    len_experts = layer.num_experts
+
+    cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts))
+    cnts.scatter_(1, topk_ids.to(torch.int64), 1)
+    tokens_per_expert = cnts.sum(dim=0)
+    idxs = topk_ids.view(-1).argsort()
+
+    sorted_tokens = x[idxs // topk_ids.shape[1]]
+    tokens_per_expert = tokens_per_expert.cpu().numpy()
+
+    outputs = []
+    start_idx = 0
+    for i, num_tokens in enumerate(tokens_per_expert):
+        end_idx = start_idx + num_tokens
+        if num_tokens == 0:
+            continue
+        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
+
+        layer_w13_weight = layer.w13_weight[i]
+        layer_w2_weight = layer.w2_weight[i]
+
+        gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
+        gate_up = SiluAndMul()(gate_up)
+        expert_out = F.linear(gate_up, layer_w2_weight)
+        outputs.append(expert_out)
+        start_idx = end_idx
+
+    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
+    new_x = torch.empty_like(outs)
+
+    new_x[idxs] = outs
+    final_out = (
+        new_x.view(*topk_ids.shape, -1)
+        .type(topk_weights.dtype)
+        .mul_(topk_weights.unsqueeze(dim=-1))
+        .sum(dim=1)
+        .type(new_x.dtype)
+    )
+    return final_out
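
The moe_forward_native function added above sorts tokens by their routed expert (scatter counts, argsort of the flattened top-k ids) before running per-expert GEMMs. A standalone toy demo of just that grouping step, with made-up routing and no real expert weights:

    # Toy demo of token-to-expert grouping; sizes and routing are illustrative only.
    import torch

    num_experts = 4
    x = torch.randn(3, 8)                               # 3 tokens, hidden size 8
    topk_ids = torch.tensor([[0, 2], [2, 3], [0, 1]])   # top-2 experts per token

    cnts = topk_ids.new_zeros((topk_ids.shape[0], num_experts))
    cnts.scatter_(1, topk_ids, 1)
    tokens_per_expert = cnts.sum(dim=0)                 # tensor([2, 1, 2, 1])
    idxs = topk_ids.view(-1).argsort()                  # flat slots ordered by expert id
    sorted_tokens = x[idxs // topk_ids.shape[1]]        # activations grouped by expert

    start = 0
    for expert_id, n in enumerate(tokens_per_expert.tolist()):
        group = sorted_tokens[start : start + n]        # all rows routed to expert_id
        start += n
        print(expert_id, tuple(group.shape))            # experts 0 and 2 get 2 rows each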
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6

@@ -15,15 +15,18 @@ from vllm import _custom_ops as ops

 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
+from sglang.srt.utils import (
+    direct_register_custom_op,
+    get_device_name,
+    is_cuda_available,
+    is_hip,
+)

-is_hip_flag = False
-if not is_hip():
+is_cuda = is_cuda_available()
+is_hip_flag = is_hip()
+if is_cuda:
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size

-    is_hip_flag = False
-else:
-    is_hip_flag = True

 logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0