sglang 0.4.2.post3__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (88)
  1. sglang/check_env.py +1 -0
  2. sglang/global_config.py +2 -0
  3. sglang/srt/constrained/outlines_backend.py +4 -1
  4. sglang/srt/entrypoints/engine.py +2 -2
  5. sglang/srt/layers/attention/flashinfer_backend.py +265 -147
  6. sglang/srt/layers/attention/triton_backend.py +358 -72
  7. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  8. sglang/srt/layers/linear.py +12 -5
  9. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +2 -2
  10. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  11. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +2 -2
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +200 -0
  13. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +200 -0
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +200 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +178 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +200 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +175 -0
  18. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -5
  19. sglang/srt/layers/moe/fused_moe_triton/layer.py +2 -0
  20. sglang/srt/layers/moe/topk.py +1 -1
  21. sglang/srt/layers/quantization/__init__.py +51 -5
  22. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  24. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +30 -30
  25. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  26. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  31. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +29 -29
  32. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  33. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  34. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +33 -33
  35. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  36. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +31 -31
  37. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  38. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +27 -27
  39. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  40. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +31 -31
  41. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  48. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +24 -24
  49. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  50. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +30 -30
  51. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  52. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +42 -42
  53. sglang/srt/layers/quantization/fp8_kernel.py +123 -17
  54. sglang/srt/layers/quantization/fp8_utils.py +33 -4
  55. sglang/srt/lora/backend/__init__.py +25 -5
  56. sglang/srt/lora/backend/base_backend.py +31 -9
  57. sglang/srt/lora/backend/flashinfer_backend.py +41 -4
  58. sglang/srt/lora/backend/triton_backend.py +34 -4
  59. sglang/srt/lora/layers.py +293 -0
  60. sglang/srt/lora/lora.py +101 -326
  61. sglang/srt/lora/lora_manager.py +101 -269
  62. sglang/srt/lora/mem_pool.py +174 -0
  63. sglang/srt/lora/triton_ops/__init__.py +7 -1
  64. sglang/srt/lora/triton_ops/gate_up_lora_b.py +170 -0
  65. sglang/srt/lora/triton_ops/qkv_lora_b.py +5 -5
  66. sglang/srt/lora/triton_ops/sgemm_lora_a.py +2 -2
  67. sglang/srt/lora/triton_ops/sgemm_lora_b.py +2 -2
  68. sglang/srt/lora/utils.py +141 -0
  69. sglang/srt/managers/detokenizer_manager.py +1 -0
  70. sglang/srt/managers/io_struct.py +4 -0
  71. sglang/srt/managers/schedule_batch.py +16 -3
  72. sglang/srt/managers/scheduler.py +29 -0
  73. sglang/srt/managers/tokenizer_manager.py +6 -0
  74. sglang/srt/managers/tp_worker_overlap_thread.py +4 -0
  75. sglang/srt/model_executor/cuda_graph_runner.py +16 -1
  76. sglang/srt/model_executor/model_runner.py +12 -2
  77. sglang/srt/models/deepseek_v2.py +17 -7
  78. sglang/srt/server_args.py +20 -1
  79. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  80. sglang/srt/speculative/eagle_utils.py +64 -21
  81. sglang/srt/speculative/eagle_worker.py +29 -8
  82. sglang/srt/utils.py +7 -0
  83. sglang/version.py +1 -1
  84. {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/METADATA +6 -5
  85. {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/RECORD +88 -55
  86. {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/LICENSE +0 -0
  87. {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/WHEEL +0 -0
  88. {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/top_level.txt +0 -0
@@ -1,61 +1,61 @@
 {
     "1": {
-        "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 32,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 4,
+        "GROUP_SIZE_M": 8,
         "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "2": {
-        "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 16,
-        "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 32,
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 8,
         "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "4": {
-        "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 32,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 1,
+        "GROUP_SIZE_M": 32,
         "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "8": {
-        "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 4,
+        "GROUP_SIZE_M": 16,
         "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "16": {
-        "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 32,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 16,
+        "GROUP_SIZE_M": 8,
         "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "24": {
-        "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 32,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 16,
+        "GROUP_SIZE_M": 8,
         "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "32": {
-        "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 32,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 16,
         "num_warps": 4,
@@ -64,52 +64,52 @@
     },
     "48": {
         "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_N": 32,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 16,
+        "GROUP_SIZE_M": 1,
         "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "64": {
         "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_N": 64,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 16,
+        "GROUP_SIZE_M": 4,
         "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "96": {
-        "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 1,
+        "GROUP_SIZE_M": 4,
         "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "128": {
-        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_M": 128,
         "BLOCK_SIZE_N": 32,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 1,
+        "GROUP_SIZE_M": 16,
         "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "256": {
         "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 1,
+        "GROUP_SIZE_M": 16,
         "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "512": {
-        "BLOCK_SIZE_M": 128,
-        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 32,
         "num_warps": 4,
@@ -117,28 +117,28 @@
         "waves_per_eu": 0
     },
     "1024": {
-        "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 4,
+        "GROUP_SIZE_M": 8,
         "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "1536": {
         "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 1,
+        "GROUP_SIZE_M": 4,
         "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "2048": {
-        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_M": 32,
         "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 1,
+        "GROUP_SIZE_M": 4,
         "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0
@@ -156,7 +156,7 @@
         "BLOCK_SIZE_M": 64,
         "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 1,
+        "GROUP_SIZE_M": 4,
         "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0
sglang/srt/layers/quantization/fp8_kernel.py
@@ -27,6 +27,10 @@ from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
 is_hip_ = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
 
+_is_cuda = torch.cuda.is_available() and torch.version.cuda
+if _is_cuda:
+    from sgl_kernel import sgl_per_token_group_quant_fp8
+
 logger = logging.getLogger(__name__)
 
 
@@ -72,11 +76,60 @@ def _per_token_group_quant_fp8(
     tl.store(y_s_ptr, y_s)
 
 
+@triton.jit
+def _per_token_group_quant_fp8_colmajor(
+    # Pointers to inputs and output
+    y_ptr,
+    y_q_ptr,
+    y_s_ptr,
+    group_size,
+    # Num columns of y
+    y_num_columns,
+    # Stride from one column to the next of y_s
+    y_s_col_stride,
+    # Avoid to divide zero
+    eps,
+    # Information for float8
+    fp8_min,
+    fp8_max,
+    # Meta-parameters
+    BLOCK: tl.constexpr,
+):
+    """A Triton-accelerated function to perform per-token-group
+    quantization on a tensor.
+    This function converts the tensor values into float8 values.
+    """
+    # Map the program id to the row of X and Y it should compute.
+    g_id = tl.program_id(0)
+    y_ptr += g_id * group_size
+    y_q_ptr += g_id * group_size
+
+    # Convert g_id the flattened block coordinate to 2D so we can index
+    # into the output y_scales matrix
+    blocks_per_row = y_num_columns // group_size
+    scale_col = g_id % blocks_per_row
+    scale_row = g_id // blocks_per_row
+    y_s_ptr += scale_col * y_s_col_stride + scale_row
+
+    cols = tl.arange(0, BLOCK)  # group_size <= BLOCK
+    mask = cols < group_size
+
+    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
+    # Quant
+    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+    y_s = _absmax / fp8_max
+    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+    tl.store(y_q_ptr + cols, y_q, mask=mask)
+    tl.store(y_s_ptr, y_s)
+
+
 def per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
     dtype: torch.dtype = fp8_type_,
+    column_major_scales: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
 
@@ -108,30 +161,83 @@ def per_token_group_quant_fp8(
     x_q = torch.empty_like(x, device=x.device, dtype=dtype)
     M = x.numel() // group_size
     N = group_size
-    x_s = torch.empty(
-        x.shape[:-1] + (x.shape[-1] // group_size,),
-        device=x.device,
-        dtype=torch.float32,
-    )
+    if column_major_scales:
+        x_s = torch.empty(
+            (x.shape[-1] // group_size,) + x.shape[:-1],
+            device=x.device,
+            dtype=torch.float32,
+        ).permute(-1, -2)
+    else:
+        x_s = torch.empty(
+            x.shape[:-1] + (x.shape[-1] // group_size,),
+            device=x.device,
+            dtype=torch.float32,
+        )
 
     BLOCK = triton.next_power_of_2(N)
     # heuristics for number of warps
     num_warps = min(max(BLOCK // 256, 1), 8)
     num_stages = 1
-    _per_token_group_quant_fp8[(M,)](
-        x,
-        x_q,
-        x_s,
-        group_size,
-        N,
-        eps,
-        fp8_min=fp8_min,
-        fp8_max=fp8_max,
-        BLOCK=BLOCK,
-        num_warps=num_warps,
-        num_stages=num_stages,
+    if column_major_scales:
+        _per_token_group_quant_fp8_colmajor[(M,)](
+            x,
+            x_q,
+            x_s,
+            group_size,
+            x.shape[1],
+            x_s.stride(1),
+            eps,
+            fp8_min=fp8_min,
+            fp8_max=fp8_max,
+            BLOCK=BLOCK,
+            num_warps=num_warps,
+            num_stages=num_stages,
+        )
+    else:
+        _per_token_group_quant_fp8[(M,)](
+            x,
+            x_q,
+            x_s,
+            group_size,
+            N,
+            eps,
+            fp8_min=fp8_min,
+            fp8_max=fp8_max,
+            BLOCK=BLOCK,
+            num_warps=num_warps,
+            num_stages=num_stages,
+        )
+
+    return x_q, x_s
+
+
+def sglang_per_token_group_quant_fp8(
+    x: torch.Tensor,
+    group_size: int,
+    eps: float = 1e-10,
+    dtype: torch.dtype = fp8_type_,
+):
+    assert (
+        x.shape[-1] % group_size == 0
+    ), "the last dimension of `x` cannot be divisible by `group_size`"
+    assert x.is_contiguous(), "`x` is not contiguous"
+
+    finfo = torch.finfo(dtype)
+    fp8_max = finfo.max
+
+    fp8_min = -fp8_max
+
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    M = x.numel() // group_size
+    N = group_size
+    x_s = torch.empty(
+        x.shape[:-1] + (x.shape[-1] // group_size,),
+        device=x.device,
+        dtype=torch.float32,
     )
 
+    sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
+
     return x_q, x_s
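For orientation, the quantization that both Triton kernels above (and the sgl_per_token_group_quant_fp8 CUDA path) perform is an absmax scale per contiguous group of group_size values. A minimal pure-PyTorch reference sketch, with illustrative names and none of the scale-layout or kernel-launch details this diff actually changes:

import torch

def per_token_group_quant_fp8_reference(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: torch.dtype = torch.float8_e4m3fn,
):
    # One fp32 absmax scale per group of `group_size` values, then clamp and cast to fp8.
    finfo = torch.finfo(dtype)
    groups = x.reshape(-1, group_size).float()
    absmax = groups.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
    scales = absmax / finfo.max
    q = (groups / scales).clamp(finfo.min, finfo.max).to(dtype)
    x_q = q.reshape(x.shape)
    x_s = scales.reshape(*x.shape[:-1], x.shape[-1] // group_size)
    return x_q, x_s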
sglang/srt/layers/quantization/fp8_utils.py
@@ -10,6 +10,9 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 from sglang.srt.utils import is_hip
 
 is_hip_ = is_hip()
+_is_cuda = torch.cuda.is_available() and torch.version.cuda
+if _is_cuda:
+    from sgl_kernel import fp8_blockwise_scaled_mm
 
 
 def normalize_e4m3fn_to_e4m3fnuz(
@@ -36,6 +39,19 @@ def normalize_e4m3fn_to_e4m3fnuz(
     return weight, weight_scale, input_scale
 
 
+def cutlass_block_fp8_supported() -> bool:
+    if _is_cuda:
+        major, minor = torch.cuda.get_device_capability()
+        sm_version = major * 10 + minor
+        cuda_version = tuple(map(int, torch.version.cuda.split(".")))
+        if cuda_version >= (12, 0) and sm_version >= 90:
+            return True
+    return False
+
+
+CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()
+
+
 def apply_w8a8_block_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -48,11 +64,24 @@ def apply_w8a8_block_fp8_linear(
     # View input as 2D matrix for fp8 methods
     input_2d = input.view(-1, input.shape[-1])
     output_shape = [*input.shape[:-1], weight.shape[0]]
-
-    q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1])
-    output = w8a8_block_fp8_matmul(
-        q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
+    # TODO: add more robust shape check here
+    shape_supported_by_cutlass = (
+        weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
     )
+    if CUTLASS_BLOCK_FP8_SUPPORTED and shape_supported_by_cutlass:
+        q_input, x_scale = per_token_group_quant_fp8(
+            input_2d, block_size[1], column_major_scales=True
+        )
+        output = fp8_blockwise_scaled_mm(
+            q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
+        )
+    else:
+        q_input, x_scale = per_token_group_quant_fp8(
+            input_2d, block_size[1], column_major_scales=False
+        )
+        output = w8a8_block_fp8_matmul(
+            q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
+        )
 
     if bias is not None:
         output = output + bias
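The behavioural difference introduced by column_major_scales is only the memory layout of the per-group activation scales, which the CUTLASS fp8_blockwise_scaled_mm path expects in column-major order. A small illustration with made-up sizes (4 tokens, hidden size 256, group size 128):

import torch

s, hidden, group = 4, 256, 128

# Default (row-major) scales: shape (s, hidden // group), contiguous storage.
x_s_row = torch.empty(s, hidden // group, dtype=torch.float32)

# column_major_scales=True allocates the transpose and permutes it back, so the
# logical shape is unchanged but the storage is column major.
x_s_col = torch.empty(hidden // group, s, dtype=torch.float32).permute(-1, -2)

print(x_s_row.shape, x_s_row.stride())  # torch.Size([4, 2]) (2, 1)
print(x_s_col.shape, x_s_col.stride())  # torch.Size([4, 2]) (1, 4)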
sglang/srt/lora/backend/__init__.py
@@ -1,8 +1,28 @@
-from .base_backend import BaseLoraBackend
-from .flashinfer_backend import FlashInferLoraBackend
-from .triton_backend import TritonLoraBackend
+from .base_backend import BaseLoRABackend
+from .flashinfer_backend import FlashInferLoRABackend
+from .triton_backend import TritonLoRABackend
+
+
+def get_backend_from_name(name: str) -> BaseLoRABackend:
+    """
+    Get corresponding backend class from backend's name
+    """
+    backend_mapping = {
+        "triton": TritonLoRABackend,
+        "flashinfer": FlashInferLoRABackend,
+    }
+
+    if name in backend_mapping:
+        return backend_mapping[name]
+
+    raise Exception(
+        f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}"
+    )
+
 
 __all__ = [
-    "FlashInferLoraBackend",
-    "TritonLoraBackend",
+    "BaseLoRABackend",
+    "FlashInferLoRABackend",
+    "TritonLoRABackend",
+    "get_backend_from_name",
 ]
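A usage sketch for the new helper (the actual call site is presumably the LoRA manager, which also changed in this release):

from sglang.srt.lora.backend import get_backend_from_name

backend_cls = get_backend_from_name("triton")  # -> TritonLoRABackend
lora_backend = backend_cls(name="triton")      # batch_info is attached later via set_batch_info()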
sglang/srt/lora/backend/base_backend.py
@@ -2,7 +2,7 @@ from typing import Tuple, Union
 
 import torch
 
-from sglang.srt.lora.lora import LoraBatchInfo
+from sglang.srt.lora.utils import LoRABatchInfo
 
 
 def get_fuse_output_scaling_add_from_name(name: str) -> bool:
@@ -13,7 +13,7 @@ def get_fuse_output_scaling_add_from_name(name: str) -> bool:
     return mapping.get(name, False)
 
 
-def get_fuse_qkv_lora_b_from_name(name: str) -> bool:
+def get_fuse_stacked_lora_b_from_name(name: str) -> bool:
     mapping = {
         "triton": True,
         "flashinfer": False,
@@ -21,7 +21,7 @@ def get_fuse_qkv_lora_b_from_name(name: str) -> bool:
     return mapping.get(name, False)
 
 
-class BaseLoraBackend:
+class BaseLoRABackend:
     """Base class for different Lora backends.
     Each backend has its own implementation of Lora kernels.
 
@@ -32,11 +32,11 @@ class BaseLoraBackend:
     and the operation of scaling and adding will be fused into kernel
     """
 
-    def __init__(self, name: str, batch_info: LoraBatchInfo = None):
+    def __init__(self, name: str, batch_info: LoRABatchInfo = None):
         self.name = name
         self.batch_info = batch_info
         self.fuse_output_scaling_add = get_fuse_output_scaling_add_from_name(name)
-        self.fuse_qkv_lora_b = get_fuse_qkv_lora_b_from_name(name)
+        self.fuse_stacked_lora_b = get_fuse_stacked_lora_b_from_name(name)
 
     def run_lora_a_sgemm(
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
@@ -46,10 +46,11 @@ class BaseLoraBackend:
 
         Args:
            x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
-           weights: a set of lora weights with shape (num_lora, r, input_dim), here r is lora rank
+           weights: a set of lora weights with shape (num_lora, c * r, input_dim),
+               here r is lora rank, c is a multiplier for stacked modules (e.g., c=3 for qkv_proj, c=2 for gate_up_proj)
            usually input_dim is much larger than r
         Returns:
-           result with shape (s, r)
+           result with shape (s, c * r)
         """
         pass
 
@@ -83,7 +84,7 @@ class BaseLoraBackend:
            qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim)
           qkv_lora_b: lora_b module for qkv.
               If passed in as a tensor, its shape should be (num_lora,output_dim_q + 2 * output_dim_kv, r)
-              If passed in as a tuple of two tensors containing:
+              If passed in as a tuple of two tensors, it should contain:
               a lora_b module for q, with shape (1, num_lora, output_dim_q, r)
               and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r)
         Returns:
@@ -91,5 +92,26 @@ class BaseLoraBackend:
         """
         pass
 
-    def set_batch_info(self, batch_info: LoraBatchInfo):
+    def run_gate_up_lora(
+        self,
+        x: torch.Tensor,
+        gate_up_lora_a: torch.Tensor,
+        gate_up_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
+        *args,
+        **kwargs
+    ) -> torch.Tensor:
+        """Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
+
+        Args:
+            x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
+            gate_up_lora_a: lora_a module for gate_up_proj, with shape (num_lora, 2 * r, input_dim)
+            gate_up_lora_b: lora_b module for qkv.
+                If passed in as a tensor, its shape should be (num_lora, 2 * output_dim, r)
+                If passed in as a tuple, it should contain two tensors with shape (num_lora, output_dim, r)
+        Returns:
+            result with shape (s, 2 * output_dim)
+        """
+        pass
+
+    def set_batch_info(self, batch_info: LoRABatchInfo):
         self.batch_info = batch_info
sglang/srt/lora/backend/flashinfer_backend.py
@@ -2,17 +2,17 @@ from typing import Tuple
 
 import torch
 
-from sglang.srt.lora.backend import BaseLoraBackend
-from sglang.srt.lora.lora import LoraBatchInfo
+from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.utils import LoRABatchInfo
 from sglang.srt.utils import is_flashinfer_available
 
 if is_flashinfer_available():
     from flashinfer import SegmentGEMMWrapper
 
 
-class FlashInferLoraBackend(BaseLoraBackend):
+class FlashInferLoRABackend(BaseLoRABackend):
 
-    def __init__(self, name: str, batch_info: LoraBatchInfo = None):
+    def __init__(self, name: str, batch_info: LoRABatchInfo = None):
         super().__init__(name, batch_info)
 
         # Set up SGemm Wrapper from flashinfer
@@ -55,6 +55,8 @@ class FlashInferLoraBackend(BaseLoraBackend):
         **kwargs,
     ) -> torch.Tensor:
 
+        assert isinstance(qkv_lora_b, tuple) and len(qkv_lora_b) == 2
+
         # Shape of lora_a_output: (s, 3 * r)
         lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a)
 
@@ -89,3 +91,38 @@ class FlashInferLoraBackend(BaseLoraBackend):
         )
 
         return lora_output
+
+    def run_gate_up_lora(
+        self,
+        x: torch.Tensor,
+        gate_up_lora_a: torch.Tensor,
+        gate_up_lora_b: Tuple[torch.Tensor],
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+
+        assert isinstance(gate_up_lora_b, tuple) and len(gate_up_lora_b) == 2
+        lora_rank = gate_up_lora_b[0].shape[-1]
+        output_dim = gate_up_lora_b[0].shape[-2]
+
+        # Shape of lora_a_output: (s, 2 * r)
+        lora_a_output = self.run_lora_a_sgemm(x=x, weights=gate_up_lora_a)
+
+        lora_output = torch.empty(
+            (x.shape[0], 2 * output_dim),
+            device=x.device,
+            dtype=x.dtype,
+        )
+
+        # Compute lora for gate and up proj respectively
+        lora_output[:, :output_dim] = self.run_lora_b_sgemm(
+            x=lora_a_output[:, :lora_rank].contiguous(),
+            weights=gate_up_lora_b[0],
+        )
+
+        lora_output[:, output_dim:] = self.run_lora_b_sgemm(
+            x=lora_a_output[:, lora_rank:].contiguous(),
+            weights=gate_up_lora_b[1],
+        )
+
+        return lora_output
sglang/srt/lora/backend/triton_backend.py
@@ -1,17 +1,18 @@
 import torch
 
-from sglang.srt.lora.backend import BaseLoraBackend
-from sglang.srt.lora.lora import LoraBatchInfo
+from sglang.srt.lora.backend import BaseLoRABackend
 from sglang.srt.lora.triton_ops import (
+    gate_up_lora_b_fwd,
     qkv_lora_b_fwd,
     sgemm_lora_a_fwd,
     sgemm_lora_b_fwd,
 )
+from sglang.srt.lora.utils import LoRABatchInfo
 
 
-class TritonLoraBackend(BaseLoraBackend):
+class TritonLoRABackend(BaseLoRABackend):
 
-    def __init__(self, name: str, batch_info: LoraBatchInfo = None):
+    def __init__(self, name: str, batch_info: LoRABatchInfo = None):
         super().__init__(name, batch_info)
 
     def run_lora_a_sgemm(
@@ -59,3 +60,32 @@ class TritonLoraBackend(BaseLoraBackend):
             scaling,
         )
         return lora_output
+
+    def run_gate_up_lora(
+        self,
+        x: torch.Tensor,
+        gate_up_lora_a: torch.Tensor,
+        gate_up_lora_b: torch.Tensor,
+        base_output: torch.Tensor = None,
+        scaling: float = 1.0,
+        *args,
+        **kwargs
+    ) -> torch.Tensor:
+
+        # x: (s, input_dim)
+        # gate_up_lora_a: (num_lora, 2 * r, input_dim)
+        # gate_up_lora_b: (num_lora, 2 * output_dim, r)
+        assert isinstance(gate_up_lora_b, torch.Tensor)
+        output_dim = gate_up_lora_b.shape[-2] // 2
+
+        # lora_a_output: (s, 2 * r)
+        lora_a_output = sgemm_lora_a_fwd(x, gate_up_lora_a, self.batch_info)
+        lora_output = gate_up_lora_b_fwd(
+            lora_a_output,
+            gate_up_lora_b,
+            self.batch_info,
+            output_dim,
+            base_output,
+            scaling,
+        )
+        return lora_output
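To make the shapes in run_gate_up_lora concrete, here is a single-adapter reference in plain PyTorch. It is a sketch with hypothetical names, not part of the diff; the real backends run the same computation batched over sequences and adapters through the Triton and FlashInfer kernels above:

import torch

def gate_up_lora_reference(
    x: torch.Tensor,           # (s, input_dim)
    lora_a: torch.Tensor,      # (2 * r, input_dim), one adapter
    lora_b: torch.Tensor,      # (2 * output_dim, r), one adapter
    scaling: float = 1.0,
    base_output: torch.Tensor = None,
) -> torch.Tensor:
    r = lora_a.shape[0] // 2
    output_dim = lora_b.shape[0] // 2
    h = x @ lora_a.T                                          # (s, 2 * r)
    delta = torch.empty(x.shape[0], 2 * output_dim, dtype=x.dtype, device=x.device)
    delta[:, :output_dim] = h[:, :r] @ lora_b[:output_dim].T  # gate_proj delta
    delta[:, output_dim:] = h[:, r:] @ lora_b[output_dim:].T  # up_proj delta
    return delta * scaling if base_output is None else base_output + delta * scaling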