sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. sglang/bench_serving.py +1 -1
  2. sglang/lang/chat_template.py +29 -0
  3. sglang/srt/_custom_ops.py +19 -17
  4. sglang/srt/configs/__init__.py +2 -0
  5. sglang/srt/configs/janus_pro.py +629 -0
  6. sglang/srt/configs/model_config.py +24 -14
  7. sglang/srt/conversation.py +80 -2
  8. sglang/srt/custom_op.py +64 -3
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
  10. sglang/srt/distributed/parallel_state.py +10 -1
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/http_server.py +1 -1
  13. sglang/srt/function_call_parser.py +33 -2
  14. sglang/srt/hf_transformers_utils.py +16 -1
  15. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
  17. sglang/srt/layers/attention/triton_backend.py +1 -3
  18. sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
  19. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
  20. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  21. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
  22. sglang/srt/layers/attention/vision.py +43 -62
  23. sglang/srt/layers/dp_attention.py +30 -2
  24. sglang/srt/layers/elementwise.py +411 -0
  25. sglang/srt/layers/linear.py +1 -1
  26. sglang/srt/layers/logits_processor.py +1 -0
  27. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  28. sglang/srt/layers/moe/ep_moe/layer.py +25 -9
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
  37. sglang/srt/layers/moe/router.py +342 -0
  38. sglang/srt/layers/parameter.py +10 -0
  39. sglang/srt/layers/quantization/__init__.py +90 -68
  40. sglang/srt/layers/quantization/blockwise_int8.py +1 -2
  41. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  49. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  51. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/fp8.py +174 -106
  68. sglang/srt/layers/quantization/fp8_kernel.py +210 -38
  69. sglang/srt/layers/quantization/fp8_utils.py +156 -15
  70. sglang/srt/layers/quantization/modelopt_quant.py +5 -1
  71. sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
  72. sglang/srt/layers/quantization/w8a8_int8.py +152 -3
  73. sglang/srt/layers/rotary_embedding.py +5 -3
  74. sglang/srt/layers/sampler.py +29 -35
  75. sglang/srt/layers/vocab_parallel_embedding.py +0 -1
  76. sglang/srt/lora/backend/__init__.py +9 -12
  77. sglang/srt/managers/cache_controller.py +74 -8
  78. sglang/srt/managers/data_parallel_controller.py +1 -1
  79. sglang/srt/managers/image_processor.py +37 -631
  80. sglang/srt/managers/image_processors/base_image_processor.py +219 -0
  81. sglang/srt/managers/image_processors/janus_pro.py +79 -0
  82. sglang/srt/managers/image_processors/llava.py +152 -0
  83. sglang/srt/managers/image_processors/minicpmv.py +86 -0
  84. sglang/srt/managers/image_processors/mlama.py +60 -0
  85. sglang/srt/managers/image_processors/qwen_vl.py +161 -0
  86. sglang/srt/managers/io_struct.py +32 -15
  87. sglang/srt/managers/multi_modality_padding.py +134 -0
  88. sglang/srt/managers/schedule_batch.py +213 -118
  89. sglang/srt/managers/schedule_policy.py +40 -8
  90. sglang/srt/managers/scheduler.py +176 -683
  91. sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
  92. sglang/srt/managers/tokenizer_manager.py +6 -6
  93. sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
  94. sglang/srt/mem_cache/base_prefix_cache.py +6 -8
  95. sglang/srt/mem_cache/chunk_cache.py +12 -44
  96. sglang/srt/mem_cache/hiradix_cache.py +71 -34
  97. sglang/srt/mem_cache/memory_pool.py +81 -17
  98. sglang/srt/mem_cache/paged_allocator.py +283 -0
  99. sglang/srt/mem_cache/radix_cache.py +117 -36
  100. sglang/srt/model_executor/cuda_graph_runner.py +68 -20
  101. sglang/srt/model_executor/forward_batch_info.py +23 -10
  102. sglang/srt/model_executor/model_runner.py +63 -63
  103. sglang/srt/model_loader/loader.py +2 -1
  104. sglang/srt/model_loader/weight_utils.py +1 -1
  105. sglang/srt/models/deepseek_janus_pro.py +2127 -0
  106. sglang/srt/models/deepseek_nextn.py +23 -3
  107. sglang/srt/models/deepseek_v2.py +200 -191
  108. sglang/srt/models/grok.py +374 -119
  109. sglang/srt/models/minicpmv.py +28 -89
  110. sglang/srt/models/mllama.py +1 -1
  111. sglang/srt/models/qwen2.py +0 -1
  112. sglang/srt/models/qwen2_5_vl.py +25 -50
  113. sglang/srt/models/qwen2_vl.py +33 -49
  114. sglang/srt/openai_api/adapter.py +59 -35
  115. sglang/srt/openai_api/protocol.py +8 -1
  116. sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
  117. sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
  118. sglang/srt/server_args.py +24 -16
  119. sglang/srt/speculative/eagle_worker.py +75 -39
  120. sglang/srt/utils.py +104 -9
  121. sglang/test/runners.py +104 -10
  122. sglang/test/test_block_fp8.py +106 -16
  123. sglang/test/test_custom_ops.py +88 -0
  124. sglang/test/test_utils.py +20 -4
  125. sglang/utils.py +0 -4
  126. sglang/version.py +1 -1
  127. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
  128. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
  129. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
  130. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
  131. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_kernel.py
@@ -22,17 +22,54 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
-
-is_hip_ = is_hip()
-fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
-
-_is_cuda = torch.cuda.is_available() and torch.version.cuda
+from sglang.srt.utils import (
+    direct_register_custom_op,
+    get_device_core_count,
+    get_device_name,
+    is_cuda,
+    is_hip,
+    supports_custom_op,
+)
+
+_is_hip = is_hip()
+fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+
+_is_cuda = is_cuda()
 if _is_cuda:
-    from sgl_kernel import sgl_per_token_group_quant_fp8
+    import deep_gemm  # `pip install "sgl-kernel>=0.0.4.post3"`
+    from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
 
 logger = logging.getLogger(__name__)
 
+_enable_jit_deepgemm = int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "0"))
+
+if supports_custom_op():
+
+    def deep_gemm_fp8_fp8_bf16_nt(
+        A: torch.Tensor,
+        As: torch.Tensor,
+        B: torch.Tensor,
+        Bs: torch.Tensor,
+        C: torch.Tensor,
+    ) -> None:
+        deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+
+    def deep_gemm_fp8_fp8_bf16_nt_fake(
+        A: torch.Tensor,
+        As: torch.Tensor,
+        B: torch.Tensor,
+        Bs: torch.Tensor,
+        C: torch.Tensor,
+    ) -> None:
+        return
+
+    direct_register_custom_op(
+        op_name="deep_gemm_fp8_fp8_bf16_nt",
+        op_func=deep_gemm_fp8_fp8_bf16_nt,
+        mutates_args=["C"],
+        fake_impl=deep_gemm_fp8_fp8_bf16_nt_fake,
+    )
+
 
 @triton.jit
 def _per_token_group_quant_fp8(
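The hunk above wires DeepGEMM in behind an environment flag and registers the call as a torch custom op (with a no-op fake implementation so that graph tracing can handle the in-place write to C). A hedged usage sketch, assuming a CUDA build with deep_gemm and sgl-kernel installed; only the flag name SGL_ENABLE_JIT_DEEPGEMM and the import path come from the diff, everything else is illustrative:

```python
# Illustrative sketch: opting into the JIT DeepGEMM path registered above.
# The flag is read once at module import time, so set it beforehand.
import os

os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1"

from sglang.srt.layers.quantization.fp8_kernel import w8a8_block_fp8_matmul

# With the flag set (and a bf16 output tensor), w8a8_block_fp8_matmul dispatches
# to deep_gemm.gemm_fp8_fp8_bf16_nt, via torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt
# when custom ops are supported, instead of the Triton block-wise kernel.
```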
@@ -70,7 +107,8 @@ def _per_token_group_quant_fp8(
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
     y_s = _absmax / fp8_max
-    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+    y_s_inv = 1.0 / y_s
+    y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
 
     tl.store(y_q_ptr + cols, y_q, mask=mask)
     tl.store(y_s_ptr, y_s)
@@ -140,7 +178,7 @@ def per_token_group_quant_fp8(
         x: The input tenosr with ndim >= 2.
         group_size: The group size used for quantization.
         eps: The minimum to avoid dividing zero.
-        dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
+        dtype: The dype of output tensor.
 
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
@@ -153,7 +191,7 @@ def per_token_group_quant_fp8(
     finfo = torch.finfo(dtype)
     fp8_max = finfo.max
 
-    if is_hip_:
+    if _is_hip:
         fp8_max = 224.0
 
     fp8_min = -fp8_max
@@ -241,6 +279,132 @@ def sglang_per_token_group_quant_fp8(
     return x_q, x_s
 
 
+def sglang_per_token_quant_fp8(
+    x: torch.Tensor,
+    dtype: torch.dtype = fp8_type_,
+):
+    assert x.is_contiguous(), "`x` is not contiguous"
+
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_s = torch.empty(
+        x.shape[0],
+        1,
+        device=x.device,
+        dtype=torch.float32,
+    )
+
+    sgl_per_token_quant_fp8(x, x_q, x_s)
+
+    return x_q, x_s
+
+
+@triton.jit
+def _static_quant_fp8(
+    # Pointers to inputs and output
+    y_ptr,
+    y_q_ptr,
+    y_s_ptr,
+    y_s_repeat_ptr,
+    # Stride of input
+    y_stride,
+    # Collums of input
+    N,
+    # Information for float8
+    fp8_min,
+    fp8_max,
+    # Meta-parameters
+    BLOCK: tl.constexpr,
+    REPEAT_SCALE: tl.constexpr,
+):
+    """A Triton-accelerated function to perform quantization using the given scale on a
+    tensor
+
+    This function converts the tensor values into float8 values.
+    """
+    # Map the program id to the row of X and Y it should compute.
+    g_id = tl.program_id(0)
+    y_ptr += g_id * y_stride
+    y_q_ptr += g_id * y_stride
+    if REPEAT_SCALE:
+        y_s_repeat_ptr += g_id
+
+    cols = tl.arange(0, BLOCK)  # N <= BLOCK
+    mask = cols < N
+
+    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
+    y_s = tl.load(y_s_ptr).to(tl.float32)
+    y_s_inv = 1.0 / y_s
+    y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+    tl.store(y_q_ptr + cols, y_q, mask=mask)
+    if REPEAT_SCALE:
+        tl.store(y_s_repeat_ptr, y_s)
+
+
+def static_quant_fp8(
+    x: torch.Tensor,
+    x_s: torch.Tensor,
+    repeat_scale: bool = False,
+    dtype: torch.dtype = fp8_type_,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Function to perform static quantization using the given scale on an input tensor `x`.
+
+    It converts the tensor values into signed float8 values and returns the
+    quantized tensor along with the scaling factor used for quantization.
+
+    Args:
+        x: The input tenosr with ndim >= 2.
+        x_s: The quantization scale.
+        repeat_scale: Whether to broadcast per-tensor scale to per-channel scale.
+        dtype: The dype of output tensor.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
+    """
+    assert x.is_contiguous(), "`x` is not contiguous"
+    assert x_s.numel() == 1, "only supports per-tensor scale"
+    finfo = torch.finfo(dtype)
+    fp8_max = finfo.max
+
+    if _is_hip:
+        fp8_max = 224.0
+
+    fp8_min = -fp8_max
+
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    M = x.numel() // x.shape[-1]
+    N = x.shape[-1]
+    if repeat_scale:
+        x_s_repeat = torch.empty(
+            (M, 1),
+            device=x.device,
+            dtype=torch.float32,
+        )
+    else:
+        x_s_repeat = None
+
+    BLOCK = triton.next_power_of_2(N)
+    # heuristics for number of warps
+    num_warps = min(max(BLOCK // 256, 1), 8)
+    num_stages = 1
+    _static_quant_fp8[(M,)](
+        x,
+        x_q,
+        x_s,
+        x_s_repeat,
+        N,
+        N,
+        fp8_min=fp8_min,
+        fp8_max=fp8_max,
+        BLOCK=BLOCK,
+        REPEAT_SCALE=repeat_scale,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
+    x_s = x_s_repeat if repeat_scale else x_s
+    return x_q, x_s
+
+
 
 @triton.jit
 def _w8a8_block_fp8_matmul(
@@ -595,34 +759,42 @@ def w8a8_block_fp8_matmul(
     num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
         N, config["BLOCK_SIZE_N"]
     )
-    kernel = (
-        _w8a8_block_fp8_matmul_unrolledx4
-        if (is_hip_ == True and num_workgroups <= get_device_core_count())
-        else _w8a8_block_fp8_matmul
-    )
 
-    kernel[grid](
-        A,
-        B,
-        C,
-        As,
-        Bs,
-        M,
-        N,
-        K,
-        block_n,
-        block_k,
-        A.stride(-2),
-        A.stride(-1),
-        B.stride(1),
-        B.stride(0),
-        C.stride(-2),
-        C.stride(-1),
-        As.stride(-2),
-        As.stride(-1),
-        Bs.stride(1),
-        Bs.stride(0),
-        **config,
-    )
+    # deepgemm only support bf16
+    if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
+        if supports_custom_op():
+            torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
+        else:
+            deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+    else:
+        kernel = (
+            _w8a8_block_fp8_matmul_unrolledx4
+            if (_is_hip == True and num_workgroups <= get_device_core_count())
+            else _w8a8_block_fp8_matmul
+        )
+
+        kernel[grid](
+            A,
+            B,
+            C,
+            As,
+            Bs,
+            M,
+            N,
+            K,
+            block_n,
+            block_k,
+            A.stride(-2),
+            A.stride(-1),
+            B.stride(1),
+            B.stride(0),
+            C.stride(-2),
+            C.stride(-1),
+            As.stride(-2),
+            As.stride(-1),
+            Bs.stride(1),
+            Bs.stride(0),
+            **config,
+        )
 
     return C
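The kernel changes above keep the same group-quantization semantics and only trade the per-element division for a multiply by the reciprocal scale. As a package-independent sketch of what per_token_group_quant_fp8 computes (the eps default and the scale layout here are assumptions; the real kernel can also emit column-major scales):

```python
import torch


def per_token_group_quant_fp8_ref(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: torch.dtype = torch.float8_e4m3fn,
):
    """Reference: split the last dim into groups, scale each group by absmax / fp8_max."""
    assert x.shape[-1] % group_size == 0
    fp8_max = torch.finfo(dtype).max
    g = x.float().reshape(-1, group_size)  # one row per group
    scale = g.abs().amax(dim=-1, keepdim=True).clamp_min(eps) / fp8_max
    q = (g / scale).clamp(-fp8_max, fp8_max).to(dtype)
    return q.reshape_as(x), scale.reshape(*x.shape[:-1], -1)


x = torch.randn(2, 256)
q, s = per_token_group_quant_fp8_ref(x, group_size=128)
print(q.shape, s.shape)  # torch.Size([2, 256]) torch.Size([2, 2])
```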
sglang/srt/layers/quantization/fp8_utils.py
@@ -1,23 +1,53 @@
-import os
 from typing import List, Optional, Tuple
 
 import torch
 
-from sglang.srt.layers.parameter import RowvLLMParameter, _ColumnvLLMParameter
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
+    static_quant_fp8,
     w8a8_block_fp8_matmul,
 )
-from sglang.srt.utils import get_bool_env_var, is_hip
+from sglang.srt.utils import (
+    get_bool_env_var,
+    get_cuda_version,
+    get_device_capability,
+    is_cuda,
+    is_hip,
+)
+
+use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
 
-is_hip_ = is_hip()
-if is_hip_ and get_bool_env_var("CK_MOE"):
+_is_hip = is_hip()
+if _is_hip and get_bool_env_var("CK_MOE"):
     from aiter import gemm_a8w8_blockscale
 
-_is_cuda = torch.cuda.is_available() and torch.version.cuda
+_is_cuda = is_cuda()
 if _is_cuda:
     from sgl_kernel import fp8_blockwise_scaled_mm
 
+    from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8
+
+    if use_vllm_cutlass_w8a8_fp8_kernel:
+        from vllm import _custom_ops as ops
+    else:
+        from sgl_kernel import fp8_scaled_mm
+
+# Input scaling factors are no longer optional in _scaled_mm starting
+# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
+TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
+
+
+def cutlass_fp8_supported():
+    if not _is_cuda:
+        return False
+    major, minor = get_device_capability()
+    cuda_version = get_cuda_version()
+    if major >= 9:
+        return cuda_version >= (12, 0)
+    elif major == 8 and minor == 9:
+        return cuda_version >= (12, 4)
+    return False
+
 
 def normalize_e4m3fn_to_e4m3fnuz(
     weight: torch.Tensor,
@@ -44,7 +74,7 @@ def normalize_e4m3fn_to_e4m3fnuz(
 
 
 def cutlass_block_fp8_supported() -> bool:
-    if os.environ.get("SUPPORT_CUTLASS_BLOCK_FP8") is None:
+    if get_bool_env_var("SUPPORT_CUTLASS_BLOCK_FP8"):
         return False
     if _is_cuda:
         major, minor = torch.cuda.get_device_capability()
@@ -81,7 +111,7 @@ def apply_w8a8_block_fp8_linear(
         output = fp8_blockwise_scaled_mm(
             q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
         )
-    elif is_hip_ and get_bool_env_var("CK_MOE"):
+    elif _is_hip and get_bool_env_var("CK_MOE"):
         q_input, x_scale = per_token_group_quant_fp8(
             input_2d, block_size[1], column_major_scales=False
         )
@@ -112,7 +142,7 @@ def input_to_float8(
     min_val, max_val = x.aminmax()
     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
     fp8_max = finfo.max
-    if is_hip_:
+    if _is_hip:
         fp8_max = 224.0
     scale = fp8_max / amax
     x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
@@ -158,10 +188,121 @@ def block_quant_to_tensor_quant(
     return x_q_tensor, scale
 
 
-class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
-    """
-    Parameter class for weight scales loaded for weights with
-    block-wise quantization. Uses both column and row parallelism.
-    """
+def apply_fp8_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+    input_scale_ub: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+    cutlass_fp8_supported: bool = True,
+    use_per_token_if_dynamic: bool = False,
+) -> torch.Tensor:
+    # View input as 2D matrix for fp8 methods
+    input_2d = input.view(-1, input.shape[-1])
+    output_shape = [*input.shape[:-1], weight.shape[1]]
+
+    # cutlass w8a8 fp8 sgl-kernel only supports per-token scale
+    if input_scale is not None:
+        assert input_scale.numel() == 1
+        # broadcast per-tensor scale to per-token scale when supporting cutlass
+        qinput, x_scale = static_quant_fp8(
+            input_2d, input_scale, repeat_scale=cutlass_fp8_supported
+        )
+    else:
+        # default use per-token quantization if dynamic
+        if _is_cuda:
+            qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
+        else:
+            qinput, x_scale = per_token_group_quant_fp8(
+                input_2d, group_size=input_2d.shape[1]
+            )
+
+    if cutlass_fp8_supported:
+        if use_vllm_cutlass_w8a8_fp8_kernel:
+            # Fall back to vllm cutlass w8a8 fp8 kernel
+            output = ops.cutlass_scaled_mm(
+                qinput,
+                weight,
+                out_dtype=input.dtype,
+                scale_a=x_scale,
+                scale_b=weight_scale,
+                bias=bias,
+            )
+        else:
+            assert (
+                weight_scale.numel() == weight.shape[1]
+            ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
+            output = fp8_scaled_mm(
+                qinput, weight, x_scale, weight_scale, out_dtype=input.dtype, bias=bias
+            )
+        return output.view(*output_shape)
+
+    # torch.scaled_mm supports per tensor weights + activations only
+    # so fallback to naive if per channel or per token
+    else:
+        per_tensor_weights = weight_scale.numel() == 1
+        per_tensor_activations = x_scale.numel() == 1
+
+        if per_tensor_weights and per_tensor_activations:
+            # Fused GEMM_DQ
+            output = torch._scaled_mm(
+                qinput,
+                weight,
+                out_dtype=input.dtype,
+                scale_a=x_scale,
+                scale_b=weight_scale,
+                bias=bias,
+            )
+            # A fix for discrepancy in scaled_mm which returns tuple
+            # for torch < 2.5 and a single value in torch >= 2.5
+            if type(output) is tuple and len(output) == 2:
+                output = output[0]
+
+            return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
+
+        else:
+            # Fallback for channelwise case, where we use unfused DQ
+            # due to limitations with scaled_mm
+
+            # Symmetric quantized GEMM by definition computes the following:
+            # C = (s_x * X) (s_w * W) + bias
+            # This is equivalent to dequantizing the weights and activations
+            # before applying a GEMM.
+            #
+            # In order to compute quantized operands, a quantized kernel
+            # will rewrite the above like so:
+            # C = s_w * s_x * (X * W) + bias
+            #
+            # For the scaled_mm fallback case, we break this down, since it
+            # does not support s_w being a vector.
+
+            # Making sure the dummy tensor is on the same device as the weight
+            global TORCH_DEVICE_IDENTITY
+            if TORCH_DEVICE_IDENTITY.device != weight.device:
+                TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
+
+            # GEMM
+            # This computes C = (X * W).
+            # Output in fp32 to allow subsequent ops to happen in-place
+            output = torch._scaled_mm(
+                qinput,
+                weight,
+                scale_a=TORCH_DEVICE_IDENTITY,
+                scale_b=TORCH_DEVICE_IDENTITY,
+                out_dtype=torch.float32,
+            )
+            # A fix for discrepancy in scaled_mm which returns tuple
+            # for torch < 2.5 and a single value in torch >= 2.5
+            if type(output) is tuple and len(output) == 2:
+                output = output[0]
+            # Unpad (undo num_token_padding)
+            output = torch.narrow(output, 0, 0, input_2d.shape[0])
+            x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
 
-    pass
+            # DQ
+            # C = sw * sx * (X * W) + bias
+            output = output * x_scale * weight_scale.t()
+            if bias is not None:
+                output = output + bias
+            return output.to(dtype=input.dtype).view(*output_shape)
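The new apply_fp8_linear follows the usual W8A8 recipe: quantize the activations (per token when dynamic), run the FP8 GEMM, then dequantize with the product of activation and weight scales, which is what the channelwise fallback spells out as C = s_x * s_w * (X * W) + bias. A small self-contained sketch of that identity with ordinary float tensors (purely illustrative; no FP8 types involved, and the 127.0 divisor is an arbitrary stand-in for a quantization range):

```python
import torch

torch.manual_seed(0)
M, K, N = 4, 8, 6
X = torch.randn(M, K)
W = torch.randn(K, N)

# Per-token activation scales (M, 1) and per-channel weight scales (N, 1),
# matching the layout the per-channel path above expects.
s_x = X.abs().amax(dim=1, keepdim=True) / 127.0
s_w = W.abs().amax(dim=0, keepdim=True).t() / 127.0

# "Quantized" operands (kept in float here just to show the algebra).
X_q = X / s_x
W_q = W / s_w.t()

# C = s_x * s_w^T * (X_q @ W_q)  ==  X @ W  (up to rounding in a real int8/fp8 path)
C = (X_q @ W_q) * s_x * s_w.t()
print(torch.allclose(C, X @ W, atol=1e-5))  # True
```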
sglang/srt/layers/quantization/modelopt_quant.py
@@ -7,7 +7,7 @@ import torch
 from torch.nn.parameter import Parameter
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear,
+    convert_to_channelwise,
     cutlass_fp8_supported,
     requantize_with_max_scale,
 )
@@ -19,6 +19,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8_utils import apply_fp8_linear
 
 # Initialize logger for the module
 logger = logging.getLogger(__name__)
@@ -161,6 +162,9 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
             layer.weight, layer.weight_scale, layer.logical_widths
         )
         layer.weight = Parameter(quantized_weight.t(), requires_grad=False)
+        # cutlass sgl-kernel only supports per-channel scale
+        if self.cutlass_fp8_supported:
+            max_w_scale = convert_to_channelwise(max_w_scale, layer.logical_widths)
         layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
         layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
 
sglang/srt/layers/quantization/w8a8_fp8.py (new file)
@@ -0,0 +1,128 @@
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch.nn.parameter import Parameter
+
+from sglang.srt.layers.linear import LinearMethodBase
+from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.fp8_utils import (
+    apply_fp8_linear,
+    cutlass_fp8_supported,
+    normalize_e4m3fn_to_e4m3fnuz,
+)
+from sglang.srt.utils import is_hip
+
+_is_hip = is_hip()
+
+
+class W8A8Fp8Config(QuantizationConfig):
+    """Config class for W8A8 FP8 Quantization.
+
+    - Weight: static, per-channel, symmetric
+    - Activation: dynamic, per-token, symmetric
+    """
+
+    def __init__(self):
+        pass
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 89
+
+    @classmethod
+    def get_name(self) -> str:
+        return "w8a8_fp8"
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return []
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config":
+        return cls()
+
+    def get_quant_method(
+        self,
+        layer: torch.nn.Module,
+        prefix: str,
+    ) -> Optional["QuantizeMethodBase"]:
+        from sglang.srt.layers.linear import LinearBase
+
+        if isinstance(layer, LinearBase):
+            return W8A8Fp8LinearMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class W8A8Fp8LinearMethod(LinearMethodBase):
+
+    def __init__(self, quantization_config: W8A8Fp8Config):
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.quantization_config = quantization_config
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        weight = layer.weight
+        weight_scale = layer.weight_scale.detach()
+        if _is_hip:
+            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                weight=weight, weight_scale=weight_scale
+            )
+        layer.weight = Parameter(weight.t(), requires_grad=False)
+        layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs
+    ):
+
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        self.logical_widths = output_partition_sizes
+
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                sum(output_partition_sizes),
+                input_size_per_partition,
+                dtype=torch.float8_e4m3fn,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+        weight_scale = ChannelQuantScaleParameter(
+            data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight_scale", weight_scale)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ):
+        return apply_fp8_linear(
+            x,
+            layer.weight,
+            layer.weight_scale,
+            bias=bias,
+            cutlass_fp8_supported=self.cutlass_fp8_supported,
+        )
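W8A8Fp8LinearMethod above expects each linear layer's checkpoint to provide an FP8 weight of shape [out_features, in_features] plus a per-output-channel float32 weight_scale of shape [out_features, 1]; after loading, the weight is transposed for the GEMM. A minimal stand-alone sketch of that layout (tensor names mirror the code above; the example values are made up and nothing here is part of the package):

```python
import torch

out_features, in_features = 6, 8

# What the loader is expected to provide per linear layer:
weight = torch.randn(out_features, in_features).to(torch.float8_e4m3fn)  # fp8 weight
weight_scale = torch.rand(out_features, 1, dtype=torch.float32)          # per-channel scale

# process_weights_after_loading() keeps the scale as-is and transposes the
# weight so the runtime GEMM sees an [in_features, out_features] operand.
weight_t = weight.t()

# apply() then calls apply_fp8_linear(x, weight_t, weight_scale, ...), which
# quantizes x per token and dequantizes the product with x_scale * weight_scale.
# Shape-wise, the dequantized weight behaves like:
x = torch.randn(3, in_features)
y_ref = x @ (weight_t.float() * weight_scale.t())
print(weight_t.shape, weight_scale.shape, y_ref.shape)
```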