sglang 0.4.0__py3-none-any.whl → 0.4.0.post2__py3-none-any.whl

This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (72)
  1. sglang/__init__.py +1 -1
  2. sglang/bench_offline_throughput.py +18 -6
  3. sglang/bench_one_batch.py +13 -0
  4. sglang/bench_serving.py +8 -1
  5. sglang/check_env.py +140 -48
  6. sglang/lang/backend/runtime_endpoint.py +1 -0
  7. sglang/lang/chat_template.py +32 -0
  8. sglang/llama3_eval.py +316 -0
  9. sglang/srt/constrained/outlines_backend.py +5 -0
  10. sglang/srt/constrained/xgrammar_backend.py +9 -6
  11. sglang/srt/layers/attention/__init__.py +5 -2
  12. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  13. sglang/srt/layers/attention/flashinfer_backend.py +22 -5
  14. sglang/srt/layers/attention/torch_native_backend.py +22 -8
  15. sglang/srt/layers/attention/triton_backend.py +38 -33
  16. sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
  17. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  18. sglang/srt/layers/ep_moe/__init__.py +0 -0
  19. sglang/srt/layers/ep_moe/kernels.py +349 -0
  20. sglang/srt/layers/ep_moe/layer.py +665 -0
  21. sglang/srt/layers/fused_moe_triton/fused_moe.py +64 -21
  22. sglang/srt/layers/fused_moe_triton/layer.py +1 -1
  23. sglang/srt/layers/logits_processor.py +133 -95
  24. sglang/srt/layers/quantization/__init__.py +2 -47
  25. sglang/srt/layers/quantization/fp8.py +607 -0
  26. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  27. sglang/srt/layers/radix_attention.py +11 -2
  28. sglang/srt/layers/sampler.py +29 -5
  29. sglang/srt/layers/torchao_utils.py +58 -45
  30. sglang/srt/managers/detokenizer_manager.py +37 -17
  31. sglang/srt/managers/io_struct.py +39 -10
  32. sglang/srt/managers/schedule_batch.py +39 -24
  33. sglang/srt/managers/schedule_policy.py +64 -5
  34. sglang/srt/managers/scheduler.py +236 -197
  35. sglang/srt/managers/tokenizer_manager.py +99 -58
  36. sglang/srt/managers/tp_worker_overlap_thread.py +7 -5
  37. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  38. sglang/srt/mem_cache/chunk_cache.py +2 -2
  39. sglang/srt/mem_cache/memory_pool.py +5 -1
  40. sglang/srt/mem_cache/radix_cache.py +12 -2
  41. sglang/srt/model_executor/cuda_graph_runner.py +39 -11
  42. sglang/srt/model_executor/model_runner.py +24 -9
  43. sglang/srt/model_parallel.py +67 -10
  44. sglang/srt/models/commandr.py +2 -2
  45. sglang/srt/models/deepseek_v2.py +87 -7
  46. sglang/srt/models/gemma2.py +34 -0
  47. sglang/srt/models/gemma2_reward.py +0 -1
  48. sglang/srt/models/granite.py +517 -0
  49. sglang/srt/models/grok.py +72 -13
  50. sglang/srt/models/llama.py +22 -5
  51. sglang/srt/models/llama_classification.py +11 -23
  52. sglang/srt/models/llama_reward.py +0 -2
  53. sglang/srt/models/llava.py +37 -14
  54. sglang/srt/models/mixtral.py +12 -9
  55. sglang/srt/models/phi3_small.py +0 -5
  56. sglang/srt/models/qwen2.py +20 -0
  57. sglang/srt/models/qwen2_moe.py +0 -5
  58. sglang/srt/models/torch_native_llama.py +0 -5
  59. sglang/srt/openai_api/adapter.py +4 -0
  60. sglang/srt/openai_api/protocol.py +9 -4
  61. sglang/srt/sampling/sampling_batch_info.py +9 -8
  62. sglang/srt/server.py +4 -4
  63. sglang/srt/server_args.py +62 -13
  64. sglang/srt/utils.py +57 -10
  65. sglang/test/test_utils.py +3 -2
  66. sglang/utils.py +10 -3
  67. sglang/version.py +1 -1
  68. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/METADATA +15 -9
  69. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/RECORD +72 -65
  70. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/LICENSE +0 -0
  71. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/WHEEL +0 -0
  72. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/top_level.txt +0 -0

sglang/srt/layers/fused_moe_triton/fused_moe.py

@@ -16,6 +16,7 @@ from vllm import _custom_ops as ops
 from sglang.srt.utils import direct_register_custom_op, get_device_name
 
 logger = logging.getLogger(__name__)
+padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
 
 
 @triton.jit
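
Note: the new module-level padding_size flag is read once from the MOE_PADDING environment variable. The later hunks subtract a padded_size from the K dimension of the expert weight shapes, which suggests the fp8 expert weights are zero-padded along their last dimension when the flag is set. A minimal sketch of that idea, using a hypothetical pad_moe_weight helper (illustration only, not the package's loading code):

    import os

    import torch
    import torch.nn.functional as F

    # Same toggle as the diff: MOE_PADDING=1 enables a 128-element pad, otherwise 0.
    padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0

    def pad_moe_weight(w: torch.Tensor) -> torch.Tensor:
        """Zero-pad an expert weight tensor [E, N, K] along its last (K) dimension."""
        # F.pad takes (left, right) pad amounts for the last dimension.
        return F.pad(w, (0, padding_size), mode="constant", value=0.0)

    w1 = torch.randn(2, 64, 256)            # toy [num_experts, N, K]
    w1_padded = pad_moe_weight(w1)
    assert w1_padded.shape[-1] == w1.shape[-1] + padding_size
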
@@ -58,6 +59,7 @@ def fused_moe_kernel(
     compute_type: tl.constexpr,
     use_fp8_w8a8: tl.constexpr,
     use_int8_w8a16: tl.constexpr,
+    even_Ks: tl.constexpr,
 ):
     """
     Implements the fused computation for a Mixture of Experts (MOE) using
@@ -143,12 +145,21 @@ def fused_moe_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the
         # K dimension.
-        a = tl.load(
-            a_ptrs,
-            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
-            other=0.0,
-        )
-        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        if even_Ks:
+            a = tl.load(
+                a_ptrs,
+                mask=token_mask[:, None],
+                other=0.0,
+            )
+            b = tl.load(b_ptrs)
+        else:
+            a = tl.load(
+                a_ptrs,
+                mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+                other=0.0,
+            )
+            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
         # We accumulate along the K dimension.
         if use_int8_w8a16:
             accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
@@ -254,7 +265,9 @@ def invoke_fused_moe_kernel(
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
 
+    padded_size = 0
     if use_fp8_w8a8:
+        padded_size = padding_size
         A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         assert B_scale is not None
     elif use_int8_w8a16:
@@ -268,6 +281,12 @@
         * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
     )
 
+    K = B.shape[2] - padded_size
+    if K % config["BLOCK_SIZE_K"] == 0:
+        even_Ks = True
+    else:
+        even_Ks = False
+
     fused_moe_kernel[grid](
         A,
         B,
@@ -279,7 +298,7 @@
         expert_ids,
         num_tokens_post_padded,
         B.shape[1],
-        B.shape[2],
+        B.shape[2] - padded_size,
         sorted_token_ids.shape[0],
         topk_ids.numel(),
         A.stride(0),
@@ -296,6 +315,7 @@
         compute_type=compute_type,
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
+        even_Ks=even_Ks,
         **config,
     )
 
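
Note: even_Ks lets the kernel drop the K-boundary mask whenever K is an exact multiple of BLOCK_SIZE_K, so every tile load is full and no other=0.0 fill is needed. A tiny standalone sketch of the same tiling argument (illustrative only):

    def k_tiles_need_mask(K: int, block_size_k: int) -> bool:
        # A boundary mask is only needed when the last K tile is partial,
        # i.e. when K is not a multiple of BLOCK_SIZE_K.
        return K % block_size_k != 0

    # With K = 4096 and BLOCK_SIZE_K = 128 every tile is full, so the masked
    # tl.load path (mask=offs_k < K - k * BLOCK_SIZE_K) can be skipped.
    assert not k_tiles_need_mask(4096, 128)   # even_Ks == True
    assert k_tiles_need_mask(4000, 128)       # falls back to the masked loads
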
@@ -351,20 +371,39 @@ def get_default_config(
     dtype: Optional[str],
     is_marlin: bool,
 ) -> Dict[str, int]:
-    config = {
-        "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 64,
-        "BLOCK_SIZE_K": 32,
-        "GROUP_SIZE_M": 8,
-    }
-    # A heuristic: fused marlin works faster with this config for small M
-    if M <= E or (is_marlin and M <= 32):
+    if dtype == "fp8_w8a8":
         config = {
-            "BLOCK_SIZE_M": 16,
-            "BLOCK_SIZE_N": 32,
-            "BLOCK_SIZE_K": 64,
-            "GROUP_SIZE_M": 1,
+            "BLOCK_SIZE_M": 128,
+            "BLOCK_SIZE_N": 256,
+            "BLOCK_SIZE_K": 128,
+            "GROUP_SIZE_M": 32,
+            "num_warps": 8,
+            "num_stages": 4,
         }
+        if M <= E:
+            config = {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": 128,
+                "BLOCK_SIZE_K": 128,
+                "GROUP_SIZE_M": 1,
+                "num_warps": 4,
+                "num_stages": 4,
+            }
+    else:
+        config = {
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": 64,
+            "BLOCK_SIZE_K": 32,
+            "GROUP_SIZE_M": 8,
+        }
+        # A heuristic: fused marlin works faster with this config for small M
+        if M <= E or (is_marlin and M <= 32):
+            config = {
+                "BLOCK_SIZE_M": 16,
+                "BLOCK_SIZE_N": 32,
+                "BLOCK_SIZE_K": 64,
+                "GROUP_SIZE_M": 1,
+            }
     return config
 
 
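
Note: the default launch configuration now branches on dtype == "fp8_w8a8", using larger tiles plus explicit num_warps/num_stages, while everything else keeps the previous defaults and the small-M/marlin heuristic. A standalone sketch of the selection logic with a simplified signature (not the package's function):

    from typing import Dict, Optional

    def pick_moe_config(M: int, E: int, dtype: Optional[str], is_marlin: bool = False) -> Dict[str, int]:
        """Simplified mirror of the new get_default_config branching."""
        if dtype == "fp8_w8a8":
            # Large fp8 tiles; shrink BLOCK_SIZE_M when there are few tokens per expert.
            if M <= E:
                return {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128,
                        "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}
            return {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128,
                    "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 4}
        # Non-fp8 path: previous defaults plus the small-M / marlin heuristic.
        if M <= E or (is_marlin and M <= 32):
            return {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1}
        return {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}

    print(pick_moe_config(M=4096, E=8, dtype="fp8_w8a8")["BLOCK_SIZE_M"])  # 128
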
@@ -645,8 +684,12 @@ def fused_experts_impl(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
 ):
+    padded_size = padding_size
+    if not use_fp8_w8a8:
+        padded_size = 0
+
     # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+    assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -668,7 +711,7 @@
     get_config_func = functools.partial(
         try_get_optimal_moe_config,
         w1.shape,
-        w2.shape,
+        (w2.shape[0], w2.shape[1], w2.shape[2] - padded_size),
         topk_ids.shape[1],
         config_dtype,
     )
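
Note: with padding enabled, the logical hidden size is the stored weight's last dimension minus the pad, so both the shape assert and the shape handed to the config lookup subtract padded_size. A toy bookkeeping check with made-up sizes (the exact [E, N, K] weight layouts here are assumptions for illustration):

    padding_size = 128                 # value when MOE_PADDING=1
    padded_size = padding_size         # use_fp8_w8a8=True in this sketch
    hidden = 4096                      # hidden_states.shape[1]
    intermediate = 7168                # made-up per-expert intermediate size

    # Stored fp8 expert weights carry the pad on their reduction (last) dimension.
    w1_shape = (8, 2 * intermediate, hidden + padded_size)   # [E, 2*I, hidden + pad]
    w2_shape = (8, hidden, intermediate + padded_size)       # [E, hidden, I + pad]

    assert hidden == w1_shape[2] - padded_size               # mirrors the updated assert
    w2_logical = (w2_shape[0], w2_shape[1], w2_shape[2] - padded_size)  # passed to the config lookup
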
sglang/srt/layers/fused_moe_triton/layer.py

@@ -19,7 +19,7 @@ from sglang.srt.layers.quantization.base_config import (
 )
 from sglang.srt.utils import set_weight_attrs
 
-if torch.cuda.is_available() or torch.hip.is_available():
+if torch.cuda.is_available():
     from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
 else:
     fused_experts = None  # type: ignore
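
Note: torch.hip is not a PyTorch module, so the removed torch.hip.is_available() call would raise AttributeError whenever the right-hand side of the or was evaluated; on ROCm builds torch.cuda.is_available() already reports the GPU. If a ROCm-specific probe were ever needed, torch.version.hip is the usual check (illustrative sketch, not code from the package):

    import torch

    # On ROCm wheels the CUDA-style availability check is routed through the HIP
    # runtime, so this single call covers both NVIDIA and AMD GPUs.
    has_gpu = torch.cuda.is_available()

    # torch.version.hip is None on CUDA/CPU builds and a version string on ROCm builds.
    is_rocm_build = torch.version.hip is not None
    print(has_gpu, is_rocm_build)
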
sglang/srt/layers/logits_processor.py

@@ -39,10 +39,12 @@ class LogitsProcessorOutput:
     # The logprobs of input tokens. shape: [#token, vocab_size]
     input_token_logprobs: torch.Tensor = None
 
-    # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-    input_top_logprobs: List = None
-    # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-    output_top_logprobs: List = None
+    # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k]
+    input_top_logprobs_val: List = None
+    input_top_logprobs_idx: List = None
+    # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k]
+    output_top_logprobs_val: List = None
+    output_top_logprobs_idx: List = None
 
 
 @dataclasses.dataclass
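
Note: the top-k results are now carried as parallel value/index lists instead of lists of (logprob, token_id) tuples, which avoids building a Python tuple per token. The two layouts convert trivially; a small sketch with made-up numbers:

    # Old layout: one list of (logprob, token_id) pairs per position.
    old_style = [(-0.11, 42), (-2.30, 7), (-3.91, 1024)]

    # New layout: parallel lists of values and indices.
    vals = [lp for lp, _ in old_style]     # [-0.11, -2.30, -3.91]
    idxs = [tok for _, tok in old_style]   # [42, 7, 1024]

    # Round-trip back to the paired form.
    assert list(zip(vals, idxs)) == old_style
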
@@ -89,76 +91,18 @@ class LogitsMetadata:
 
 
 class LogitsProcessor(nn.Module):
-    def __init__(self, config, skip_all_gather: bool = False):
+    def __init__(
+        self, config, skip_all_gather: bool = False, logit_scale: Optional[float] = None
+    ):
         super().__init__()
         self.config = config
+        self.logit_scale = logit_scale
         self.do_tensor_parallel_all_gather = (
             not skip_all_gather and get_tensor_model_parallel_world_size() > 1
         )
-
-    def _get_normalized_prompt_logprobs(
-        self,
-        input_token_logprobs: torch.Tensor,
-        logits_metadata: LogitsMetadata,
-    ):
-        logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
-        pruned_lens = torch.tensor(
-            logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
-        )
-
-        start = torch.zeros_like(pruned_lens)
-        start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
-        end = torch.clamp(
-            start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
-        )
-        sum_logp = (
-            logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
+        self.final_logit_softcapping = getattr(
+            self.config, "final_logit_softcapping", None
         )
-        normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
-        return normalized_prompt_logprobs
-
-    @staticmethod
-    def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
-        max_k = max(logits_metadata.top_logprobs_nums)
-        ret = all_logprobs.topk(max_k, dim=1)
-        values = ret.values.tolist()
-        indices = ret.indices.tolist()
-
-        if logits_metadata.forward_mode.is_decode():
-            output_top_logprobs = []
-            for i, k in enumerate(logits_metadata.top_logprobs_nums):
-                output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k])))
-            return None, output_top_logprobs
-        else:
-            input_top_logprobs, output_top_logprobs = [], []
-
-            pt = 0
-            for k, pruned_len in zip(
-                logits_metadata.top_logprobs_nums,
-                logits_metadata.extend_logprob_pruned_lens_cpu,
-            ):
-                if pruned_len <= 0:
-                    input_top_logprobs.append([])
-                    output_top_logprobs.append([])
-                    continue
-
-                input_top_logprobs.append(
-                    [
-                        list(zip(values[pt + j][:k], indices[pt + j][:k]))
-                        for j in range(pruned_len - 1)
-                    ]
-                )
-                output_top_logprobs.append(
-                    list(
-                        zip(
-                            values[pt + pruned_len - 1][:k],
-                            indices[pt + pruned_len - 1][:k],
-                        )
-                    )
-                )
-                pt += pruned_len
-
-            return input_top_logprobs, output_top_logprobs
 
     def forward(
         self,
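
Note: reading final_logit_softcapping once via getattr(..., None) lets forward() use a plain truthiness check instead of hasattr, and the new logit_scale argument is applied inside _get_logits. The softcap itself is the usual cap * tanh(logits / cap) squashing; a minimal standalone sketch of both knobs (illustrative, not the class itself):

    import torch

    def scale_and_softcap(logits: torch.Tensor, logit_scale=None, softcap=None) -> torch.Tensor:
        """Apply the optional scaling factor, then the optional final-logit softcap."""
        if logit_scale is not None:
            logits = logits * logit_scale
        if softcap:  # truthiness check, as in the new forward()
            logits = softcap * torch.tanh(logits / softcap)
        return logits

    x = torch.tensor([[-100.0, 0.0, 100.0]])
    print(scale_and_softcap(x, softcap=30.0))   # values squashed into (-30, 30)
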
@@ -184,38 +128,33 @@ class LogitsProcessor(nn.Module):
             last_logits = tensor_model_parallel_all_gather(last_logits)
         last_logits = last_logits[:, : self.config.vocab_size].float()
 
-        if hasattr(self.config, "final_logit_softcapping"):
-            last_logits.div_(self.config.final_logit_softcapping)
+        if self.final_logit_softcapping:
+            last_logits.div_(self.final_logit_softcapping)
             torch.tanh(last_logits, out=last_logits)
-            last_logits.mul_(self.config.final_logit_softcapping)
+            last_logits.mul_(self.final_logit_softcapping)
 
         # Return only last_logits if logprob is not requested
         if not logits_metadata.return_logprob:
             return LogitsProcessorOutput(
                 next_token_logits=last_logits,
-                next_token_logprobs=None,
-                normalized_prompt_logprobs=None,
-                input_token_logprobs=None,
-                input_top_logprobs=None,
-                output_top_logprobs=None,
             )
         else:
-            last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
+            last_logprobs = self.compute_temp_top_p_normalized_logprobs(
+                last_logits, logits_metadata
+            )
 
             if logits_metadata.forward_mode.is_decode():
                 if logits_metadata.return_top_logprob:
-                    output_top_logprobs = self.get_top_logprobs(
-                        last_logprobs, logits_metadata
-                    )[1]
+                    output_top_logprobs_val, output_top_logprobs_idx = (
+                        self.get_top_logprobs(last_logprobs, logits_metadata)[2:4]
+                    )
                 else:
-                    output_top_logprobs = None
+                    output_top_logprobs_val = output_top_logprobs_idx = None
                 return LogitsProcessorOutput(
                     next_token_logits=last_logits,
                     next_token_logprobs=last_logprobs,
-                    normalized_prompt_logprobs=None,
-                    input_token_logprobs=None,
-                    input_top_logprobs=None,
-                    output_top_logprobs=output_top_logprobs,
+                    output_top_logprobs_val=output_top_logprobs_val,
+                    output_top_logprobs_idx=output_top_logprobs_idx,
                 )
             else:
                 # Slice the requested tokens to compute logprob
@@ -233,24 +172,35 @@ class LogitsProcessor(nn.Module):
                 all_logits = self._get_logits(states, lm_head)
                 if self.do_tensor_parallel_all_gather:
                     all_logits = tensor_model_parallel_all_gather(all_logits)
+
+                # The LM head's weights may be zero-padded for parallelism. Remove any
+                # extra logits that this padding may have produced.
                 all_logits = all_logits[:, : self.config.vocab_size].float()
 
-                if hasattr(self.config, "final_logit_softcapping"):
-                    all_logits.div_(self.config.final_logit_softcapping)
+                if self.final_logit_softcapping:
+                    all_logits.div_(self.final_logit_softcapping)
                     torch.tanh(all_logits, out=all_logits)
-                    all_logits.mul_(self.config.final_logit_softcapping)
+                    all_logits.mul_(self.final_logit_softcapping)
 
                 all_logprobs = all_logits
                 del all_logits, hidden_states
-                all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
+
+                all_logprobs = self.compute_temp_top_p_normalized_logprobs(
+                    all_logprobs, logits_metadata
+                )
 
                 # Get the logprob of top-k tokens
                 if logits_metadata.return_top_logprob:
-                    input_top_logprobs, output_top_logprobs = self.get_top_logprobs(
-                        all_logprobs, logits_metadata
-                    )
+                    (
+                        input_top_logprobs_val,
+                        input_top_logprobs_idx,
+                        output_top_logprobs_val,
+                        output_top_logprobs_idx,
+                    ) = self.get_top_logprobs(all_logprobs, logits_metadata)
                 else:
-                    input_top_logprobs = output_top_logprobs = None
+                    input_top_logprobs_val = input_top_logprobs_idx = (
+                        output_top_logprobs_val
+                    ) = output_top_logprobs_idx = None
 
                 # Compute the normalized logprobs for the requested tokens.
                 # Note that we pad a zero at the end for easy batching.
@@ -273,8 +223,10 @@ class LogitsProcessor(nn.Module):
                     next_token_logprobs=last_logprobs,
                     normalized_prompt_logprobs=normalized_prompt_logprobs,
                     input_token_logprobs=input_token_logprobs,
-                    input_top_logprobs=input_top_logprobs,
-                    output_top_logprobs=output_top_logprobs,
+                    input_top_logprobs_val=input_top_logprobs_val,
+                    input_top_logprobs_idx=input_top_logprobs_idx,
+                    output_top_logprobs_val=output_top_logprobs_val,
+                    output_top_logprobs_idx=output_top_logprobs_idx,
                 )
 
     def _get_logits(
@@ -288,8 +240,94 @@ class LogitsProcessor(nn.Module):
         else:
             # GGUF models
             logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
+
+        # Optional scaling factor
+        if self.logit_scale is not None:
+            logits.mul_(self.logit_scale)  # In-place multiply
         return logits
 
+    @staticmethod
+    def _get_normalized_prompt_logprobs(
+        input_token_logprobs: torch.Tensor,
+        logits_metadata: LogitsMetadata,
+    ):
+        logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
+        pruned_lens = torch.tensor(
+            logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
+        )
+
+        start = torch.zeros_like(pruned_lens)
+        start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
+        end = torch.clamp(
+            start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
+        )
+        sum_logp = (
+            logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
+        )
+        normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
+        return normalized_prompt_logprobs
+
+    @staticmethod
+    def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
+        max_k = max(logits_metadata.top_logprobs_nums)
+        ret = all_logprobs.topk(max_k, dim=1)
+        values = ret.values.tolist()
+        indices = ret.indices.tolist()
+
+        if logits_metadata.forward_mode.is_decode():
+            output_top_logprobs_val = []
+            output_top_logprobs_idx = []
+            for i, k in enumerate(logits_metadata.top_logprobs_nums):
+                output_top_logprobs_val.append(values[i][:k])
+                output_top_logprobs_idx.append(indices[i][:k])
+            return None, None, output_top_logprobs_val, output_top_logprobs_idx
+        else:
+            input_top_logprobs_val, input_top_logprobs_idx = [], []
+            output_top_logprobs_val, output_top_logprobs_idx = [], []
+
+            pt = 0
+            for k, pruned_len in zip(
+                logits_metadata.top_logprobs_nums,
+                logits_metadata.extend_logprob_pruned_lens_cpu,
+            ):
+                if pruned_len <= 0:
+                    input_top_logprobs_val.append([])
+                    input_top_logprobs_idx.append([])
+                    output_top_logprobs_val.append([])
+                    output_top_logprobs_idx.append([])
+                    continue
+
+                input_top_logprobs_val.append(
+                    [values[pt + j][:k] for j in range(pruned_len - 1)]
+                )
+                input_top_logprobs_idx.append(
+                    [indices[pt + j][:k] for j in range(pruned_len - 1)]
+                )
+                output_top_logprobs_val.append(
+                    list(
+                        values[pt + pruned_len - 1][:k],
+                    )
+                )
+                output_top_logprobs_idx.append(
+                    list(
+                        indices[pt + pruned_len - 1][:k],
+                    )
+                )
+                pt += pruned_len
+
+            return (
+                input_top_logprobs_val,
+                input_top_logprobs_idx,
+                output_top_logprobs_val,
+                output_top_logprobs_idx,
+            )
+
+    @staticmethod
+    def compute_temp_top_p_normalized_logprobs(
+        last_logits: torch.Tensor, logits_metadata: LogitsMetadata
+    ) -> torch.Tensor:
+        return torch.nn.functional.log_softmax(last_logits, dim=-1)
+
 
 def test():
     all_logprobs = torch.tensor(
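
Note: compute_temp_top_p_normalized_logprobs is currently just log_softmax; the name leaves room for temperature/top-p renormalization later. The rewritten get_top_logprobs returns a 4-tuple of parallel value/index lists, and its decode branch builds them from a single topk call; a tiny standalone demo of that shape (illustrative only):

    import torch

    logprobs = torch.log_softmax(torch.randn(2, 8), dim=-1)   # [batch=2, vocab=8]
    top_logprobs_nums = [3, 1]                                 # per-request k

    max_k = max(top_logprobs_nums)
    ret = logprobs.topk(max_k, dim=1)
    values, indices = ret.values.tolist(), ret.indices.tolist()

    output_top_logprobs_val = [values[i][:k] for i, k in enumerate(top_logprobs_nums)]
    output_top_logprobs_idx = [indices[i][:k] for i, k in enumerate(top_logprobs_nums)]
    print(output_top_logprobs_val, output_top_logprobs_idx)   # per-request lists of length 3 and 1
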
sglang/srt/layers/quantization/__init__.py

@@ -13,7 +13,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
 from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
 from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
 from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
-from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from vllm.model_executor.layers.quantization.gguf import GGUFConfig
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
@@ -23,6 +22,7 @@ from vllm.model_executor.layers.quantization.qqq import QQQConfig
 from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.fp8 import Fp8Config
 
 QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "aqlm": AQLMConfig,
@@ -53,60 +53,16 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     return QUANTIZATION_METHODS[quantization]
 
 
-def fp8_moe_apply(
-    self,
-    layer: torch.nn.Module,
-    x: torch.Tensor,
-    router_logits: torch.Tensor,
-    top_k: int,
-    renormalize: bool,
-    use_grouped_topk: bool,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-    custom_routing_function: Optional[Callable] = None,
-) -> torch.Tensor:
-    """Enhanced apply method for FP8 MoE."""
-    from sglang.srt.layers.fused_moe_triton import FusedMoE
-    from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
-
-    # Expert selection
-    topk_weights, topk_ids = FusedMoE.select_experts(
-        hidden_states=x,
-        router_logits=router_logits,
-        use_grouped_topk=use_grouped_topk,
-        top_k=top_k,
-        renormalize=renormalize,
-        topk_group=topk_group,
-        num_expert_group=num_expert_group,
-        custom_routing_function=custom_routing_function,
-    )
-
-    # Expert fusion with FP8 quantization
-    return fused_experts(
-        x,
-        layer.w13_weight,
-        layer.w2_weight,
-        topk_weights=topk_weights,
-        topk_ids=topk_ids,
-        inplace=True,
-        use_fp8_w8a8=True,
-        w1_scale=layer.w13_weight_scale,
-        w2_scale=layer.w2_weight_scale,
-        a1_scale=layer.w13_input_scale,
-        a2_scale=layer.w2_input_scale,
-    )
-
-
 def fp8_get_quant_method(self, layer, prefix):
     """Enhanced get_quant_method for FP8 config."""
     from vllm.model_executor.layers.linear import LinearBase
-    from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
     from vllm.model_executor.layers.quantization.utils.quant_utils import (
         is_layer_skipped,
     )
 
     from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
     from sglang.srt.layers.linear import UnquantizedLinearMethod
+    from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod
 
     if isinstance(layer, LinearBase):
         if is_layer_skipped(prefix, self.ignored_layers):
@@ -151,7 +107,6 @@ def awq_get_quant_method(self, layer, prefix):
 
 def apply_monkey_patches():
     """Apply all monkey patches in one place."""
-    setattr(Fp8MoEMethod, "apply", fp8_moe_apply)
     setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
     setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
     setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
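
Note: with the FP8 MoE apply method now living in sglang.srt.layers.quantization.fp8 (see the new fp8.py in the file list), only the remaining vLLM config classes are monkey-patched. The pattern is plain attribute replacement on the class object; a generic sketch (not the package's classes):

    class Greeter:                      # stand-in for a third-party class
        def hello(self):
            return "hello"

    def patched_hello(self):            # replacement with a compatible signature
        return "patched hello"

    # setattr on the class rebinds the method for existing and future instances.
    setattr(Greeter, "hello", patched_hello)
    assert Greeter().hello() == "patched hello"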