sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_serving.py +0 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  12. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  13. sglang/srt/constrained/xgrammar_backend.py +26 -4
  14. sglang/srt/custom_op.py +0 -62
  15. sglang/srt/disaggregation/decode.py +62 -6
  16. sglang/srt/disaggregation/mini_lb.py +5 -1
  17. sglang/srt/disaggregation/mooncake/conn.py +32 -62
  18. sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
  19. sglang/srt/disaggregation/prefill.py +40 -4
  20. sglang/srt/disaggregation/utils.py +15 -0
  21. sglang/srt/entrypoints/verl_engine.py +7 -5
  22. sglang/srt/layers/activation.py +6 -8
  23. sglang/srt/layers/attention/flashattention_backend.py +114 -71
  24. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  25. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  26. sglang/srt/layers/attention/triton_backend.py +6 -0
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  28. sglang/srt/layers/layernorm.py +1 -1
  29. sglang/srt/layers/linear.py +17 -3
  30. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  31. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  34. sglang/srt/layers/moe/topk.py +27 -30
  35. sglang/srt/layers/parameter.py +0 -2
  36. sglang/srt/layers/quantization/__init__.py +1 -0
  37. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  38. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +8 -2
  39. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
  40. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  41. sglang/srt/layers/quantization/fp8.py +115 -132
  42. sglang/srt/layers/quantization/fp8_kernel.py +213 -57
  43. sglang/srt/layers/quantization/fp8_utils.py +187 -262
  44. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  45. sglang/srt/layers/quantization/utils.py +5 -11
  46. sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
  47. sglang/srt/layers/quantization/w8a8_int8.py +7 -7
  48. sglang/srt/layers/radix_attention.py +15 -0
  49. sglang/srt/layers/rotary_embedding.py +3 -2
  50. sglang/srt/layers/sampler.py +5 -10
  51. sglang/srt/lora/backend/base_backend.py +18 -2
  52. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  53. sglang/srt/lora/backend/triton_backend.py +1 -1
  54. sglang/srt/lora/layers.py +1 -1
  55. sglang/srt/lora/lora.py +1 -1
  56. sglang/srt/lora/lora_manager.py +1 -1
  57. sglang/srt/managers/detokenizer_manager.py +0 -1
  58. sglang/srt/managers/io_struct.py +1 -0
  59. sglang/srt/managers/mm_utils.py +4 -3
  60. sglang/srt/managers/multimodal_processor.py +0 -2
  61. sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
  62. sglang/srt/managers/schedule_batch.py +2 -4
  63. sglang/srt/managers/scheduler.py +12 -71
  64. sglang/srt/managers/tokenizer_manager.py +1 -0
  65. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  66. sglang/srt/mem_cache/memory_pool.py +7 -2
  67. sglang/srt/model_executor/cuda_graph_runner.py +2 -2
  68. sglang/srt/model_executor/model_runner.py +20 -27
  69. sglang/srt/models/bert.py +398 -0
  70. sglang/srt/models/deepseek.py +1 -1
  71. sglang/srt/models/deepseek_nextn.py +74 -70
  72. sglang/srt/models/deepseek_v2.py +289 -348
  73. sglang/srt/models/llama.py +5 -5
  74. sglang/srt/models/minicpm3.py +29 -201
  75. sglang/srt/models/qwen2.py +4 -1
  76. sglang/srt/models/qwen2_moe.py +14 -13
  77. sglang/srt/models/qwen3.py +335 -0
  78. sglang/srt/models/qwen3_moe.py +423 -0
  79. sglang/srt/reasoning_parser.py +0 -1
  80. sglang/srt/sampling/sampling_batch_info.py +2 -3
  81. sglang/srt/server_args.py +34 -32
  82. sglang/srt/speculative/eagle_worker.py +4 -7
  83. sglang/srt/utils.py +16 -1
  84. sglang/test/runners.py +5 -1
  85. sglang/test/test_block_fp8.py +167 -0
  86. sglang/test/test_custom_ops.py +1 -1
  87. sglang/version.py +1 -1
  88. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +3 -3
  89. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +92 -91
  90. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  91. sglang/lang/__init__.py +0 -0
  92. sglang/srt/lora/backend/__init__.py +0 -25
  93. sglang/srt/server.py +0 -18
  94. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  95. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_utils.py
@@ -3,9 +3,20 @@ from typing import List, Optional, Tuple
 
 import torch
 
+from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
+
+try:
+    from vllm import _custom_ops as ops
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    VLLM_AVAILABLE = False
+
 from sglang.srt.layers.quantization.fp8_kernel import (
     _enable_jit_deepgemm,
     per_token_group_quant_fp8,
+    scaled_fp8_quant,
+    sglang_per_token_quant_fp8,
     static_quant_fp8,
     w8a8_block_fp8_matmul,
 )
@@ -17,30 +28,20 @@ from sglang.srt.utils import (
     is_hip,
 )
 
-try:
-    import vllm
-    from vllm import _custom_ops as ops
-
-    VLLM_AVAILABLE = True
-except ImportError:
-    VLLM_AVAILABLE = False
-
-use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
-
 _is_hip = is_hip()
+_is_cuda = is_cuda()
+
 if _is_hip and get_bool_env_var("CK_MOE"):
     from aiter import gemm_a8w8_blockscale
 
-_is_cuda = is_cuda()
 if _is_cuda:
     from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm
 
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-    from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8
+use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
 
 # Input scaling factors are no longer optional in _scaled_mm starting
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
-TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
+TORCH_DEVICE_IDENTITY = None
 
 _TORCH_VERSION = torch.__version__.split("+")[0]
 try:
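Note: TORCH_DEVICE_IDENTITY now starts out as None and is materialized on first use (see the new _apply_fallback_scaled_mm helper further down), so nothing is allocated at import time and the dummy scale ends up on the weight's device. A minimal standalone sketch of that lazy-singleton pattern, assuming only plain PyTorch:

    import torch

    _IDENTITY_SCALE = None  # module-level placeholder, created on first use

    def get_identity_scale(device: torch.device) -> torch.Tensor:
        """Return a cached ones(1) scale tensor living on `device`."""
        global _IDENTITY_SCALE
        if _IDENTITY_SCALE is None or _IDENTITY_SCALE.device != device:
            _IDENTITY_SCALE = torch.ones(1, dtype=torch.float32, device=device)
        return _IDENTITY_SCALE

    # usage sketch: scale_a = get_identity_scale(weight.device)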
@@ -143,7 +144,7 @@ def apply_w8a8_block_fp8_linear(
         gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
     else:
         if _enable_jit_deepgemm:
-            q_input, x_scale = per_token_group_quant_fp8(
+            q_input, x_scale = sglang_per_token_group_quant_fp8(
                 input_2d,
                 block_size[1],
                 column_major_scales=True,
@@ -214,7 +215,7 @@ def block_quant_to_tensor_quant(
             x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
 
     x_q_tensor, scale = (
-        sgl_scaled_fp8_quant(x_dq_block)
+        scaled_fp8_quant(x_dq_block)
        if _is_cuda
        else input_to_float8(x_dq_block, dtype=x_q_block.dtype)
    )
@@ -227,13 +228,50 @@ def channel_quant_to_tensor_quant(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     x_dq_channel = x_q_channel.to(torch.float32) * x_s
     x_q_tensor, scale = (
-        sgl_scaled_fp8_quant(x_dq_channel)
+        scaled_fp8_quant(x_dq_channel)
         if _is_cuda
         else input_to_float8(x_dq_channel, dtype=x_q_channel.dtype)
     )
     return x_q_tensor, scale
 
 
+def _process_scaled_mm_output(output, input_2d_shape, output_shape):
+    if type(output) is tuple and len(output) == 2:
+        output = output[0]
+    return torch.narrow(output, 0, 0, input_2d_shape[0]).view(*output_shape)
+
+
+def _apply_fallback_scaled_mm(
+    qinput,
+    weight,
+    x_scale,
+    weight_scale,
+    input_2d_shape,
+    output_shape,
+    bias,
+    input_dtype,
+):
+    global TORCH_DEVICE_IDENTITY
+    if TORCH_DEVICE_IDENTITY is None:
+        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32, device=weight.device)
+
+    output = torch._scaled_mm(
+        qinput,
+        weight,
+        scale_a=TORCH_DEVICE_IDENTITY,
+        scale_b=TORCH_DEVICE_IDENTITY,
+        out_dtype=torch.float32,
+    )
+
+    output = _process_scaled_mm_output(output, input_2d_shape, output_shape)
+    x_scale = torch.narrow(x_scale, 0, 0, input_2d_shape[0])
+
+    output = output * x_scale * weight_scale.t()
+    if bias is not None:
+        output = output + bias
+    return output.to(dtype=input_dtype)
+
+
 def apply_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
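Note: the new _process_scaled_mm_output helper folds two previously duplicated steps into one place: torch releases before 2.5 return a (tensor, amax) tuple from torch._scaled_mm, and padded outputs must be narrowed back to the real token count before reshaping. A rough CPU-only illustration of the same normalization, with an ordinary matmul standing in for torch._scaled_mm:

    import torch

    def normalize_mm_output(output, num_tokens, output_shape):
        # Older torch returns (result, amax); keep only the result tensor.
        if isinstance(output, tuple) and len(output) == 2:
            output = output[0]
        # Drop padded rows, then restore the caller's leading dimensions.
        return torch.narrow(output, 0, 0, num_tokens).view(*output_shape)

    x = torch.randn(17, 8)   # 5 real tokens padded to 17 rows
    w = torch.randn(8, 4)
    out = normalize_mm_output((x @ w, None), num_tokens=5, output_shape=[5, 4])
    assert out.shape == (5, 4)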
@@ -241,216 +279,33 @@ def apply_fp8_linear(
     input_scale: Optional[torch.Tensor] = None,
     input_scale_ub: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
-    cutlass_fp8_supported: bool = True,
+    cutlass_fp8_supported: bool = cutlass_fp8_supported(),
     use_per_token_if_dynamic: bool = False,
+    pad_output: Optional[bool] = None,
+    compressed_tensor_quant: bool = False,
 ) -> torch.Tensor:
+    # Note: we pad the input because torch._scaled_mm is more performant
+    # for matrices with batch dimension > 16.
+    # This could change in the future.
+    # We also don't pad when using torch.compile,
+    # as it breaks with dynamic shapes.
+    if pad_output is None:
+        pad_output = not get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE")
+    output_padding = 17 if pad_output else None
+
     # View input as 2D matrix for fp8 methods
     input_2d = input.view(-1, input.shape[-1])
     output_shape = [*input.shape[:-1], weight.shape[1]]
 
-    # cutlass w8a8 fp8 sgl-kernel only supports per-token scale
-    if input_scale is not None:
-        assert input_scale.numel() == 1
-        # broadcast per-tensor scale to per-token scale when supporting cutlass
-        qinput, x_scale = static_quant_fp8(
-            input_2d, input_scale, repeat_scale=cutlass_fp8_supported
-        )
-    else:
-        # default use per-token quantization if dynamic
-        if _is_cuda:
-            qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
-        else:
-            # TODO(kkhuang): temporarily enforce per-tensor activation scaling if weight is per-tensor scaling
-            # final solution should be: 1. add support to per-tensor activation scaling.
-            # 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
-            if _is_hip and weight_scale.numel() == 1:
-                qinput, x_scale = ops.scaled_fp8_quant(
-                    input_2d,
-                    input_scale,
-                    use_per_token_if_dynamic=use_per_token_if_dynamic,
-                )
-            else:
-                qinput, x_scale = per_token_group_quant_fp8(
-                    input_2d, group_size=input_2d.shape[1]
-                )
-
-    if cutlass_fp8_supported:
-        try:
-            if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
-                # Fall back to vllm cutlass w8a8 fp8 kernel
-                output = ops.cutlass_scaled_mm(
-                    qinput,
-                    weight,
-                    out_dtype=input.dtype,
-                    scale_a=x_scale,
-                    scale_b=weight_scale,
-                    bias=bias,
-                )
-            else:
-                assert (
-                    weight_scale.numel() == weight.shape[1]
-                ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
-                output = fp8_scaled_mm(
-                    qinput,
-                    weight,
-                    x_scale,
-                    weight_scale,
-                    out_dtype=input.dtype,
-                    bias=bias,
-                )
-            return output.view(*output_shape)
-        except (ImportError, NameError, AttributeError):
-            pass
-
-    # torch.scaled_mm supports per tensor weights + activations only
-    # so fallback to naive if per channel or per token
-    else:
-        per_tensor_weights = weight_scale.numel() == 1
-        per_tensor_activations = x_scale.numel() == 1
-
-        if per_tensor_weights and per_tensor_activations:
-            # Fused GEMM_DQ
-            output = torch._scaled_mm(
-                qinput,
-                weight,
-                out_dtype=input.dtype,
-                scale_a=x_scale,
-                scale_b=weight_scale,
-                bias=bias,
-            )
-            # A fix for discrepancy in scaled_mm which returns tuple
-            # for torch < 2.5 and a single value in torch >= 2.5
-            if type(output) is tuple and len(output) == 2:
-                output = output[0]
-
-            return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
-
-        else:
-            # Fallback for channelwise case, where we use unfused DQ
-            # due to limitations with scaled_mm
-
-            # Symmetric quantized GEMM by definition computes the following:
-            # C = (s_x * X) (s_w * W) + bias
-            # This is equivalent to dequantizing the weights and activations
-            # before applying a GEMM.
-            #
-            # In order to compute quantized operands, a quantized kernel
-            # will rewrite the above like so:
-            # C = s_w * s_x * (X * W) + bias
-            #
-            # For the scaled_mm fallback case, we break this down, since it
-            # does not support s_w being a vector.
-
-            # Making sure the dummy tensor is on the same device as the weight
-            global TORCH_DEVICE_IDENTITY
-            if TORCH_DEVICE_IDENTITY.device != weight.device:
-                TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
-
-            # GEMM
-            # This computes C = (X * W).
-            # Output in fp32 to allow subsequent ops to happen in-place
-            output = torch._scaled_mm(
-                qinput,
-                weight,
-                scale_a=TORCH_DEVICE_IDENTITY,
-                scale_b=TORCH_DEVICE_IDENTITY,
-                out_dtype=torch.float32,
-            )
-            # A fix for discrepancy in scaled_mm which returns tuple
-            # for torch < 2.5 and a single value in torch >= 2.5
-            if type(output) is tuple and len(output) == 2:
-                output = output[0]
-            # Unpad (undo num_token_padding)
-            output = torch.narrow(output, 0, 0, input_2d.shape[0])
-            x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
-
-            # DQ
-            # C = sw * sx * (X * W) + bias
-            output = output * x_scale * weight_scale.t()
-            if bias is not None:
-                output = output + bias
-            return output.to(dtype=input.dtype).view(*output_shape)
-
-
-def maybe_create_device_identity():
-    # Allocate dummy ones tensor for torch._scaled_mm
-    global TORCH_DEVICE_IDENTITY
-    if TORCH_DEVICE_IDENTITY is None:
-        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
-
-
-# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
-# TODO(luka): follow similar pattern for marlin and block-fp8-linear
-# https://github.com/vllm-project/vllm/issues/14397
-class Fp8LinearOp:
-    """
-    This class executes a FP8 linear layer using cutlass if supported and
-    torch.scaled_mm otherwise.
-    It needs to be a class instead of a method so that config can be read
-    in the __init__ method, as reading config is not allowed inside forward.
-    """
-
-    def __init__(
-        self,
-        cutlass_fp8_supported: bool = cutlass_fp8_supported(),
-        use_per_token_if_dynamic: bool = False,
-        pad_output: Optional[bool] = None,
-    ):
-        self.cutlass_fp8_supported = cutlass_fp8_supported
-        self.use_per_token_if_dynamic = use_per_token_if_dynamic
-
-        # Note: we pad the input because torch._scaled_mm is more performant
-        # for matrices with batch dimension > 16.
-        # This could change in the future.
-        # We also don't pad when using torch.compile,
-        # as it breaks with dynamic shapes.
-        if pad_output is None:
-            enable_torch_compile = os.environ.get(
-                "SGLANG_ENABLE_TORCH_COMPILE", "0"
-            ).lower() in ("1", "true", "yes")
-            pad_output = not enable_torch_compile
-        self.output_padding = 17 if pad_output else None
-
-    def apply(
-        self,
-        input: torch.Tensor,
-        weight: torch.Tensor,
-        weight_scale: torch.Tensor,
-        input_scale: Optional[torch.Tensor] = None,
-        input_scale_ub: Optional[torch.Tensor] = None,
-        bias: Optional[torch.Tensor] = None,
-        # TODO(luka) remove this parameter in favor of __init__
-        use_per_token_if_dynamic: Optional[bool] = None,
-    ) -> torch.Tensor:
-        # ops.scaled_fp8_quant supports both dynamic and static quant.
-        # If dynamic, layer.input_scale is None and x_scale computed from x.
-        # If static, layer.input_scale is scalar and x_scale is input_scale.
-
-        # View input as 2D matrix for fp8 methods
-        input_2d = input.view(-1, input.shape[-1])
-        output_shape = [*input.shape[:-1], weight.shape[1]]
-
-        # TODO(luka) this is here because currently MLA only decides this
-        # during the forward method instead of in __init__.
-        if use_per_token_if_dynamic is None:
-            use_per_token_if_dynamic = self.use_per_token_if_dynamic
-
+    if compressed_tensor_quant:
         # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
         # for sgl-kernel fp8_scaled_mm, it support per channel W now
-        if self.cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]:
-            if _is_cuda:
-                qinput, x_scale = sgl_scaled_fp8_quant(
-                    input_2d,
-                    input_scale,
-                    use_per_token_if_dynamic=use_per_token_if_dynamic,
-                )
-            else:
-                qinput, x_scale = ops.scaled_fp8_quant(
-                    input_2d,
-                    input_scale,
-                    scale_ub=input_scale_ub,
-                    use_per_token_if_dynamic=use_per_token_if_dynamic,
-                )
+        if cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]:
+            qinput, x_scale = scaled_fp8_quant(
+                input_2d,
+                input_scale,
+                use_per_token_if_dynamic=use_per_token_if_dynamic,
+            )
 
             # Fused GEMM_DQ
             if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
@@ -481,20 +336,21 @@ class Fp8LinearOp:
         # so fallback to naive if per channel or per token
         else:
             # Maybe apply padding to output, see comment in __init__
-            if _is_cuda:
-                qinput, x_scale = sgl_scaled_fp8_quant(
+            qinput, x_scale = (
+                scaled_fp8_quant(
                     input_2d,
                     input_scale,
-                    num_token_padding=self.output_padding,
+                    num_token_padding=output_padding,
                     use_per_token_if_dynamic=use_per_token_if_dynamic,
                 )
-            else:
-                qinput, x_scale = ops.scaled_fp8_quant(
+                if _is_cuda
+                else ops.scaled_fp8_quant(
                     input_2d,
                     input_scale,
-                    num_token_padding=self.output_padding,
+                    num_token_padding=output_padding,
                     use_per_token_if_dynamic=use_per_token_if_dynamic,
                 )
+            )
 
         per_tensor_weights = weight_scale.numel() == 1
         per_tensor_activations = x_scale.numel() == 1
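Note: the quantization call above pads the token dimension to at least 17 rows (num_token_padding=output_padding) because torch._scaled_mm tends to be faster once the batch dimension exceeds 16; the helpers then narrow the result back. The padding is numerically transparent, since the extra rows only produce extra output rows that get discarded. A small CPU sketch of that invariant, with a plain matmul standing in for the fp8 GEMM:

    import torch

    num_tokens, hidden, out_features, padded = 3, 8, 4, 17
    x = torch.randn(num_tokens, hidden)
    w = torch.randn(hidden, out_features)

    x_padded = torch.zeros(padded, hidden)   # pad rows up to 17, as the kernel path does
    x_padded[:num_tokens] = x

    y = torch.narrow(x_padded @ w, 0, 0, num_tokens)   # undo num_token_padding
    assert torch.allclose(y, x @ w)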
@@ -509,12 +365,7 @@ class Fp8LinearOp:
                 scale_b=weight_scale,
                 bias=bias,
             )
-            # A fix for discrepancy in scaled_mm which returns tuple
-            # for torch < 2.5 and a single value in torch >= 2.5
-            if type(output) is tuple and len(output) == 2:
-                output = output[0]
-
-            return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
+            return _process_scaled_mm_output(output, input_2d.shape, output_shape)
 
         elif (
             use_per_token_if_dynamic
@@ -537,10 +388,7 @@ class Fp8LinearOp:
                 scale_b=weight_scale.t(),
                 bias=bias,
             )
-
-            output = torch.narrow(output, 0, 0, input_2d.shape[0])
-            output = output.view(*output_shape)
-            return output
+            return _process_scaled_mm_output(output, input_2d.shape, output_shape)
 
         else:
             # Fallback for channelwise case, where we use unfused DQ
@@ -557,33 +405,110 @@ class Fp8LinearOp:
             #
             # For the scaled_mm fallback case, we break this down, since it
             # does not support s_w being a vector.
-
-            # GEMM
-            # This computes C = (X * W).
-            # Output in fp32 to allow subsequent ops to happen in-place
-
-            global TORCH_DEVICE_IDENTITY
-            if TORCH_DEVICE_IDENTITY.device != weight.device:
-                TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
-
-            output = torch._scaled_mm(
+            return _apply_fallback_scaled_mm(
                 qinput,
                 weight,
-                scale_a=TORCH_DEVICE_IDENTITY,
-                scale_b=TORCH_DEVICE_IDENTITY,
-                out_dtype=torch.float32,
+                x_scale,
+                weight_scale,
+                input_2d.shape,
+                output_shape,
+                bias,
+                input.dtype,
             )
-            # A fix for discrepancy in scaled_mm which returns tuple
-            # for torch < 2.5 and a single value in torch >= 2.5
-            if type(output) is tuple and len(output) == 2:
-                output = output[0]
-            # Unpad (undo num_token_padding)
-            output = torch.narrow(output, 0, 0, input_2d.shape[0])
-            x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
-
-            # DQ
-            # C = sw * sx * (X * W) + bias
-            output = output * x_scale * weight_scale.t()
-            if bias is not None:
-                output = output + bias
-            return output.to(dtype=input.dtype).view(*output_shape)
+    else:
+        # cutlass w8a8 fp8 sgl-kernel only supports per-token scale
+        if input_scale is not None:
+            assert input_scale.numel() == 1
+            # broadcast per-tensor scale to per-token scale when supporting cutlass
+            qinput, x_scale = static_quant_fp8(
+                input_2d, input_scale, repeat_scale=cutlass_fp8_supported
+            )
+        else:
+            # default use per-token quantization if dynamic
+            if _is_cuda:
+                qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
+            else:
+                # TODO(kkhuang): temporarily enforce per-tensor activation scaling if weight is per-tensor scaling
+                # final solution should be: 1. add support to per-tensor activation scaling.
+                # 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
+                if _is_hip and weight_scale.numel() == 1:
+                    qinput, x_scale = ops.scaled_fp8_quant(
+                        input_2d,
+                        input_scale,
+                        use_per_token_if_dynamic=use_per_token_if_dynamic,
+                    )
+                else:
+                    qinput, x_scale = per_token_group_quant_fp8(
+                        input_2d, group_size=input_2d.shape[1]
+                    )
+
+        if cutlass_fp8_supported:
+            try:
+                if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
+                    # Fall back to vllm cutlass w8a8 fp8 kernel
+                    output = ops.cutlass_scaled_mm(
+                        qinput,
+                        weight,
+                        out_dtype=input.dtype,
+                        scale_a=x_scale,
+                        scale_b=weight_scale,
+                        bias=bias,
+                    )
+                else:
+                    assert (
+                        weight_scale.numel() == weight.shape[1]
+                    ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
+                    output = fp8_scaled_mm(
+                        qinput,
+                        weight,
+                        x_scale,
+                        weight_scale,
+                        out_dtype=input.dtype,
+                        bias=bias,
+                    )
+                return output.view(*output_shape)
+            except (ImportError, NameError, AttributeError):
+                pass
+
+        # torch.scaled_mm supports per tensor weights + activations only
+        # so fallback to naive if per channel or per token
+        per_tensor_weights = weight_scale.numel() == 1
+        per_tensor_activations = x_scale.numel() == 1
+
+        if per_tensor_weights and per_tensor_activations:
+            # Fused GEMM_DQ
+            output = torch._scaled_mm(
+                qinput,
+                weight,
+                out_dtype=input.dtype,
+                scale_a=x_scale,
+                scale_b=weight_scale,
+                bias=bias,
+            )
+            return _process_scaled_mm_output(output, input_2d.shape, output_shape)
+
+        else:
+            # Fallback for channelwise case, where we use unfused DQ
+            # due to limitations with scaled_mm
+
+            # Symmetric quantized GEMM by definition computes the following:
+            # C = (s_x * X) (s_w * W) + bias
+            # This is equivalent to dequantizing the weights and activations
+            # before applying a GEMM.
+            #
+            # In order to compute quantized operands, a quantized kernel
+            # will rewrite the above like so:
+            # C = s_w * s_x * (X * W) + bias
+            #
+            # For the scaled_mm fallback case, we break this down, since it
+            # does not support s_w being a vector.
+            return _apply_fallback_scaled_mm(
+                qinput,
+                weight,
+                x_scale,
+                weight_scale,
+                input_2d.shape,
+                output_shape,
+                bias,
+                input.dtype,
+            )
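Note: the channelwise fallback relies on the identity spelled out in the comments: C = (s_x * X)(s_w * W) + bias can be computed as (X * W) scaled afterwards by the per-token factors s_x and the per-channel factors s_w. A quick CPU check of that factorization with plain float tensors (no fp8 involved), roughly mirroring what _apply_fallback_scaled_mm computes:

    import torch

    tokens, hidden, out_features = 4, 8, 6
    x_q = torch.randn(tokens, hidden)          # stand-in for quantized activations
    w_q = torch.randn(hidden, out_features)    # stand-in for quantized weights
    x_scale = torch.rand(tokens, 1)            # per-token activation scales
    w_scale = torch.rand(out_features, 1)      # per-output-channel weight scales
    bias = torch.randn(out_features)

    dequant_first = (x_q * x_scale) @ (w_q * w_scale.t()) + bias     # dequantize, then GEMM
    unfused_dq = (x_q @ w_q) * x_scale * w_scale.t() + bias          # GEMM, then scale (the fallback)
    assert torch.allclose(dequant_first, unfused_dq, atol=1e-5)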
sglang/srt/layers/quantization/moe_wna16.py
@@ -347,6 +347,7 @@ class MoeWNA16Method:
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         # avoid circular import
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -363,6 +364,7 @@ class MoeWNA16Method:
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
         weight_bits = self.quant_config.weight_bits
sglang/srt/layers/quantization/utils.py
@@ -1,18 +1,17 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
 
 from types import MappingProxyType
-from typing import List, Mapping, Optional, Tuple, Union
+from typing import List, Mapping, Tuple, Union
 
 import torch
 
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.utils import is_cuda
 
 _is_cuda = is_cuda()
 
-if _is_cuda:
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-else:
-    from vllm import _custom_ops as vllm_ops
+if not _is_cuda:
+    from vllm._custom_ops import scaled_fp8_quant
 
 
 def is_fp8_fnuz() -> bool:
@@ -116,12 +115,7 @@ def requantize_with_max_scale(
     for idx, logical_width in enumerate(logical_widths):
         end = start + logical_width
         weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
-        if _is_cuda:
-            weight[start:end, :], _ = sgl_scaled_fp8_quant(weight_dq, max_w_scale)
-        else:
-            weight[start:end, :], _ = vllm_ops.scaled_fp8_quant(
-                weight_dq, max_w_scale
-            )
+        weight[start:end, :], _ = scaled_fp8_quant(weight_dq, max_w_scale)
         start = end
 
     return max_w_scale, weight
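Note: requantize_with_max_scale (the context of the second hunk) folds the per-shard scales of a fused weight into a single maximum scale: each shard is dequantized with its own scale and requantized with the shared max, so one scale serves the whole fused tensor. A rough CPU illustration of the idea, using simple symmetric int8-style rounding as a stand-in for the fp8 quantization kernel:

    import torch

    def quant(x, scale):    # stand-in for scaled_fp8_quant
        return torch.clamp(torch.round(x / scale), -127, 127)

    def dequant(q, scale):
        return q * scale

    shards = [torch.randn(4, 8), torch.randn(4, 8) * 3]     # fused weight shards
    scales = [s.abs().max() / 127 for s in shards]          # per-shard scales
    q_shards = [quant(s, sc) for s, sc in zip(shards, scales)]

    max_scale = max(scales)                                  # shared scale for the fused weight
    requantized = [quant(dequant(q, sc), max_scale) for q, sc in zip(q_shards, scales)]
    fused_q = torch.cat(requantized, dim=0)                  # one tensor, one scale
    fused_approx = dequant(fused_q, max_scale)               # ~= torch.cat(shards, dim=0)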
sglang/srt/layers/quantization/w8a8_fp8.py
@@ -294,6 +294,7 @@ class W8A8FP8MoEMethod:
         activation: str = "silu",
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -309,6 +310,7 @@ class W8A8FP8MoEMethod:
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
         return fused_experts(
sglang/srt/layers/quantization/w8a8_int8.py
@@ -1,13 +1,6 @@
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
-
-from sglang.srt.utils import is_cuda_available, set_weight_attrs
-
-is_cuda = is_cuda_available()
-if is_cuda:
-    from sgl_kernel import int8_scaled_mm
-
 from torch.nn.parameter import Parameter
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
@@ -18,6 +11,11 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+from sglang.srt.utils import is_cuda_available, set_weight_attrs
+
+is_cuda = is_cuda_available()
+if is_cuda:
+    from sgl_kernel import int8_scaled_mm
 
 
 class W8A8Int8Config(QuantizationConfig):
@@ -233,6 +231,7 @@ class W8A8Int8MoEMethod:
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -248,6 +247,7 @@ class W8A8Int8MoEMethod:
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
         return fused_experts(
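Note: each MoE apply method in this release gains an optional routed_scaling_factor argument that is simply forwarded to select_experts; the routing internals themselves are outside this diff. Presumably the factor rescales the routed experts' contribution, along the lines of the sketch below; the helper name and the exact point where the factor is applied are illustrative assumptions, not sglang API:

    from typing import Optional
    import torch

    def scale_routed_weights(
        topk_weights: torch.Tensor,                  # [num_tokens, top_k] routing weights
        routed_scaling_factor: Optional[float] = None,
    ) -> torch.Tensor:
        # Hypothetical helper: no-op when the factor is absent, otherwise scale
        # the routed experts' contribution uniformly.
        if routed_scaling_factor is None:
            return topk_weights
        return topk_weights * routed_scaling_factor

    weights = torch.softmax(torch.randn(2, 4), dim=-1)
    print(scale_routed_weights(weights, routed_scaling_factor=2.5))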