sglang 0.4.2.post4__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. sglang/global_config.py +2 -0
  2. sglang/srt/entrypoints/engine.py +2 -2
  3. sglang/srt/layers/attention/flashinfer_backend.py +235 -110
  4. sglang/srt/layers/attention/triton_backend.py +358 -72
  5. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  6. sglang/srt/layers/linear.py +12 -5
  7. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +2 -2
  8. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  9. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +2 -2
  10. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +200 -0
  11. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +200 -0
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +200 -0
  13. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +178 -0
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +200 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +175 -0
  16. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -2
  17. sglang/srt/layers/moe/fused_moe_triton/layer.py +2 -0
  18. sglang/srt/layers/moe/topk.py +1 -1
  19. sglang/srt/layers/quantization/__init__.py +51 -5
  20. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  21. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +30 -30
  22. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  23. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +29 -29
  24. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  25. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +33 -33
  26. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  27. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +31 -31
  28. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  29. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +27 -27
  30. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  31. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +31 -31
  32. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  33. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +24 -24
  34. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  35. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +30 -30
  36. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  37. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +42 -42
  38. sglang/srt/layers/quantization/fp8_kernel.py +123 -17
  39. sglang/srt/layers/quantization/fp8_utils.py +33 -4
  40. sglang/srt/managers/detokenizer_manager.py +1 -0
  41. sglang/srt/managers/io_struct.py +4 -0
  42. sglang/srt/managers/schedule_batch.py +16 -3
  43. sglang/srt/managers/scheduler.py +29 -0
  44. sglang/srt/managers/tokenizer_manager.py +6 -0
  45. sglang/srt/managers/tp_worker_overlap_thread.py +4 -0
  46. sglang/srt/model_executor/cuda_graph_runner.py +12 -1
  47. sglang/srt/model_executor/model_runner.py +12 -2
  48. sglang/srt/models/deepseek_v2.py +17 -7
  49. sglang/srt/server_args.py +20 -1
  50. sglang/srt/speculative/eagle_worker.py +28 -8
  51. sglang/srt/utils.py +7 -0
  52. sglang/version.py +1 -1
  53. {sglang-0.4.2.post4.dist-info → sglang-0.4.3.dist-info}/METADATA +4 -3
  54. {sglang-0.4.2.post4.dist-info → sglang-0.4.3.dist-info}/RECORD +57 -41
  55. {sglang-0.4.2.post4.dist-info → sglang-0.4.3.dist-info}/LICENSE +0 -0
  56. {sglang-0.4.2.post4.dist-info → sglang-0.4.3.dist-info}/WHEEL +0 -0
  57. {sglang-0.4.2.post4.dist-info → sglang-0.4.3.dist-info}/top_level.txt +0 -0
@@ -1,61 +1,61 @@
1
1
  {
2
2
  "1": {
3
- "BLOCK_SIZE_M": 64,
4
- "BLOCK_SIZE_N": 16,
3
+ "BLOCK_SIZE_M": 32,
4
+ "BLOCK_SIZE_N": 32,
5
5
  "BLOCK_SIZE_K": 128,
6
- "GROUP_SIZE_M": 4,
6
+ "GROUP_SIZE_M": 8,
7
7
  "num_warps": 4,
8
8
  "num_stages": 2,
9
9
  "waves_per_eu": 0
10
10
  },
11
11
  "2": {
12
- "BLOCK_SIZE_M": 64,
13
- "BLOCK_SIZE_N": 16,
14
- "BLOCK_SIZE_K": 128,
15
- "GROUP_SIZE_M": 32,
12
+ "BLOCK_SIZE_M": 32,
13
+ "BLOCK_SIZE_N": 32,
14
+ "BLOCK_SIZE_K": 64,
15
+ "GROUP_SIZE_M": 8,
16
16
  "num_warps": 4,
17
17
  "num_stages": 2,
18
18
  "waves_per_eu": 0
19
19
  },
20
20
  "4": {
21
- "BLOCK_SIZE_M": 64,
22
- "BLOCK_SIZE_N": 16,
21
+ "BLOCK_SIZE_M": 32,
22
+ "BLOCK_SIZE_N": 32,
23
23
  "BLOCK_SIZE_K": 128,
24
- "GROUP_SIZE_M": 1,
24
+ "GROUP_SIZE_M": 32,
25
25
  "num_warps": 4,
26
26
  "num_stages": 2,
27
27
  "waves_per_eu": 0
28
28
  },
29
29
  "8": {
30
- "BLOCK_SIZE_M": 64,
31
- "BLOCK_SIZE_N": 16,
30
+ "BLOCK_SIZE_M": 32,
31
+ "BLOCK_SIZE_N": 64,
32
32
  "BLOCK_SIZE_K": 128,
33
- "GROUP_SIZE_M": 4,
33
+ "GROUP_SIZE_M": 16,
34
34
  "num_warps": 4,
35
35
  "num_stages": 2,
36
36
  "waves_per_eu": 0
37
37
  },
38
38
  "16": {
39
- "BLOCK_SIZE_M": 64,
40
- "BLOCK_SIZE_N": 16,
39
+ "BLOCK_SIZE_M": 32,
40
+ "BLOCK_SIZE_N": 32,
41
41
  "BLOCK_SIZE_K": 128,
42
- "GROUP_SIZE_M": 16,
42
+ "GROUP_SIZE_M": 8,
43
43
  "num_warps": 4,
44
44
  "num_stages": 2,
45
45
  "waves_per_eu": 0
46
46
  },
47
47
  "24": {
48
- "BLOCK_SIZE_M": 64,
49
- "BLOCK_SIZE_N": 16,
48
+ "BLOCK_SIZE_M": 32,
49
+ "BLOCK_SIZE_N": 32,
50
50
  "BLOCK_SIZE_K": 128,
51
- "GROUP_SIZE_M": 16,
51
+ "GROUP_SIZE_M": 8,
52
52
  "num_warps": 4,
53
53
  "num_stages": 2,
54
54
  "waves_per_eu": 0
55
55
  },
56
56
  "32": {
57
- "BLOCK_SIZE_M": 64,
58
- "BLOCK_SIZE_N": 16,
57
+ "BLOCK_SIZE_M": 32,
58
+ "BLOCK_SIZE_N": 32,
59
59
  "BLOCK_SIZE_K": 128,
60
60
  "GROUP_SIZE_M": 16,
61
61
  "num_warps": 4,
@@ -64,52 +64,52 @@
64
64
  },
65
65
  "48": {
66
66
  "BLOCK_SIZE_M": 64,
67
- "BLOCK_SIZE_N": 16,
67
+ "BLOCK_SIZE_N": 32,
68
68
  "BLOCK_SIZE_K": 128,
69
- "GROUP_SIZE_M": 16,
69
+ "GROUP_SIZE_M": 1,
70
70
  "num_warps": 4,
71
71
  "num_stages": 2,
72
72
  "waves_per_eu": 0
73
73
  },
74
74
  "64": {
75
75
  "BLOCK_SIZE_M": 64,
76
- "BLOCK_SIZE_N": 16,
76
+ "BLOCK_SIZE_N": 64,
77
77
  "BLOCK_SIZE_K": 128,
78
- "GROUP_SIZE_M": 16,
78
+ "GROUP_SIZE_M": 4,
79
79
  "num_warps": 4,
80
80
  "num_stages": 2,
81
81
  "waves_per_eu": 0
82
82
  },
83
83
  "96": {
84
- "BLOCK_SIZE_M": 64,
85
- "BLOCK_SIZE_N": 16,
84
+ "BLOCK_SIZE_M": 32,
85
+ "BLOCK_SIZE_N": 128,
86
86
  "BLOCK_SIZE_K": 128,
87
- "GROUP_SIZE_M": 1,
87
+ "GROUP_SIZE_M": 4,
88
88
  "num_warps": 4,
89
89
  "num_stages": 2,
90
90
  "waves_per_eu": 0
91
91
  },
92
92
  "128": {
93
- "BLOCK_SIZE_M": 64,
93
+ "BLOCK_SIZE_M": 128,
94
94
  "BLOCK_SIZE_N": 32,
95
95
  "BLOCK_SIZE_K": 128,
96
- "GROUP_SIZE_M": 1,
96
+ "GROUP_SIZE_M": 16,
97
97
  "num_warps": 4,
98
98
  "num_stages": 2,
99
99
  "waves_per_eu": 0
100
100
  },
101
101
  "256": {
102
102
  "BLOCK_SIZE_M": 64,
103
- "BLOCK_SIZE_N": 32,
103
+ "BLOCK_SIZE_N": 128,
104
104
  "BLOCK_SIZE_K": 128,
105
- "GROUP_SIZE_M": 1,
105
+ "GROUP_SIZE_M": 16,
106
106
  "num_warps": 4,
107
107
  "num_stages": 2,
108
108
  "waves_per_eu": 0
109
109
  },
110
110
  "512": {
111
- "BLOCK_SIZE_M": 128,
112
- "BLOCK_SIZE_N": 32,
111
+ "BLOCK_SIZE_M": 64,
112
+ "BLOCK_SIZE_N": 128,
113
113
  "BLOCK_SIZE_K": 128,
114
114
  "GROUP_SIZE_M": 32,
115
115
  "num_warps": 4,
@@ -117,28 +117,28 @@
117
117
  "waves_per_eu": 0
118
118
  },
119
119
  "1024": {
120
- "BLOCK_SIZE_M": 64,
121
- "BLOCK_SIZE_N": 64,
120
+ "BLOCK_SIZE_M": 32,
121
+ "BLOCK_SIZE_N": 128,
122
122
  "BLOCK_SIZE_K": 128,
123
- "GROUP_SIZE_M": 4,
123
+ "GROUP_SIZE_M": 8,
124
124
  "num_warps": 4,
125
125
  "num_stages": 2,
126
126
  "waves_per_eu": 0
127
127
  },
128
128
  "1536": {
129
129
  "BLOCK_SIZE_M": 64,
130
- "BLOCK_SIZE_N": 64,
130
+ "BLOCK_SIZE_N": 128,
131
131
  "BLOCK_SIZE_K": 128,
132
- "GROUP_SIZE_M": 1,
132
+ "GROUP_SIZE_M": 4,
133
133
  "num_warps": 4,
134
134
  "num_stages": 2,
135
135
  "waves_per_eu": 0
136
136
  },
137
137
  "2048": {
138
- "BLOCK_SIZE_M": 64,
138
+ "BLOCK_SIZE_M": 32,
139
139
  "BLOCK_SIZE_N": 128,
140
140
  "BLOCK_SIZE_K": 128,
141
- "GROUP_SIZE_M": 1,
141
+ "GROUP_SIZE_M": 4,
142
142
  "num_warps": 4,
143
143
  "num_stages": 2,
144
144
  "waves_per_eu": 0
@@ -156,7 +156,7 @@
156
156
  "BLOCK_SIZE_M": 64,
157
157
  "BLOCK_SIZE_N": 128,
158
158
  "BLOCK_SIZE_K": 128,
159
- "GROUP_SIZE_M": 1,
159
+ "GROUP_SIZE_M": 4,
160
160
  "num_warps": 4,
161
161
  "num_stages": 2,
162
162
  "waves_per_eu": 0
@@ -27,6 +27,10 @@ from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
27
27
  is_hip_ = is_hip()
28
28
  fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
29
29
 
30
+ _is_cuda = torch.cuda.is_available() and torch.version.cuda
31
+ if _is_cuda:
32
+ from sgl_kernel import sgl_per_token_group_quant_fp8
33
+
30
34
  logger = logging.getLogger(__name__)
31
35
 
32
36
 
@@ -72,11 +76,60 @@ def _per_token_group_quant_fp8(
72
76
  tl.store(y_s_ptr, y_s)
73
77
 
74
78
 
79
+ @triton.jit
80
+ def _per_token_group_quant_fp8_colmajor(
81
+ # Pointers to inputs and output
82
+ y_ptr,
83
+ y_q_ptr,
84
+ y_s_ptr,
85
+ group_size,
86
+ # Num columns of y
87
+ y_num_columns,
88
+ # Stride from one column to the next of y_s
89
+ y_s_col_stride,
90
+ # Avoid to divide zero
91
+ eps,
92
+ # Information for float8
93
+ fp8_min,
94
+ fp8_max,
95
+ # Meta-parameters
96
+ BLOCK: tl.constexpr,
97
+ ):
98
+ """A Triton-accelerated function to perform per-token-group
99
+ quantization on a tensor.
100
+ This function converts the tensor values into float8 values.
101
+ """
102
+ # Map the program id to the row of X and Y it should compute.
103
+ g_id = tl.program_id(0)
104
+ y_ptr += g_id * group_size
105
+ y_q_ptr += g_id * group_size
106
+
107
+ # Convert g_id the flattened block coordinate to 2D so we can index
108
+ # into the output y_scales matrix
109
+ blocks_per_row = y_num_columns // group_size
110
+ scale_col = g_id % blocks_per_row
111
+ scale_row = g_id // blocks_per_row
112
+ y_s_ptr += scale_col * y_s_col_stride + scale_row
113
+
114
+ cols = tl.arange(0, BLOCK) # group_size <= BLOCK
115
+ mask = cols < group_size
116
+
117
+ y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
118
+ # Quant
119
+ _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
120
+ y_s = _absmax / fp8_max
121
+ y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
122
+
123
+ tl.store(y_q_ptr + cols, y_q, mask=mask)
124
+ tl.store(y_s_ptr, y_s)
125
+
126
+
75
127
  def per_token_group_quant_fp8(
76
128
  x: torch.Tensor,
77
129
  group_size: int,
78
130
  eps: float = 1e-10,
79
131
  dtype: torch.dtype = fp8_type_,
132
+ column_major_scales: bool = False,
80
133
  ) -> Tuple[torch.Tensor, torch.Tensor]:
81
134
  """Function to perform per-token-group quantization on an input tensor `x`.
82
135
 
@@ -108,30 +161,83 @@ def per_token_group_quant_fp8(
108
161
  x_q = torch.empty_like(x, device=x.device, dtype=dtype)
109
162
  M = x.numel() // group_size
110
163
  N = group_size
111
- x_s = torch.empty(
112
- x.shape[:-1] + (x.shape[-1] // group_size,),
113
- device=x.device,
114
- dtype=torch.float32,
115
- )
164
+ if column_major_scales:
165
+ x_s = torch.empty(
166
+ (x.shape[-1] // group_size,) + x.shape[:-1],
167
+ device=x.device,
168
+ dtype=torch.float32,
169
+ ).permute(-1, -2)
170
+ else:
171
+ x_s = torch.empty(
172
+ x.shape[:-1] + (x.shape[-1] // group_size,),
173
+ device=x.device,
174
+ dtype=torch.float32,
175
+ )
116
176
 
117
177
  BLOCK = triton.next_power_of_2(N)
118
178
  # heuristics for number of warps
119
179
  num_warps = min(max(BLOCK // 256, 1), 8)
120
180
  num_stages = 1
121
- _per_token_group_quant_fp8[(M,)](
122
- x,
123
- x_q,
124
- x_s,
125
- group_size,
126
- N,
127
- eps,
128
- fp8_min=fp8_min,
129
- fp8_max=fp8_max,
130
- BLOCK=BLOCK,
131
- num_warps=num_warps,
132
- num_stages=num_stages,
181
+ if column_major_scales:
182
+ _per_token_group_quant_fp8_colmajor[(M,)](
183
+ x,
184
+ x_q,
185
+ x_s,
186
+ group_size,
187
+ x.shape[1],
188
+ x_s.stride(1),
189
+ eps,
190
+ fp8_min=fp8_min,
191
+ fp8_max=fp8_max,
192
+ BLOCK=BLOCK,
193
+ num_warps=num_warps,
194
+ num_stages=num_stages,
195
+ )
196
+ else:
197
+ _per_token_group_quant_fp8[(M,)](
198
+ x,
199
+ x_q,
200
+ x_s,
201
+ group_size,
202
+ N,
203
+ eps,
204
+ fp8_min=fp8_min,
205
+ fp8_max=fp8_max,
206
+ BLOCK=BLOCK,
207
+ num_warps=num_warps,
208
+ num_stages=num_stages,
209
+ )
210
+
211
+ return x_q, x_s
212
+
213
+
214
+ def sglang_per_token_group_quant_fp8(
215
+ x: torch.Tensor,
216
+ group_size: int,
217
+ eps: float = 1e-10,
218
+ dtype: torch.dtype = fp8_type_,
219
+ ):
220
+ assert (
221
+ x.shape[-1] % group_size == 0
222
+ ), "the last dimension of `x` cannot be divisible by `group_size`"
223
+ assert x.is_contiguous(), "`x` is not contiguous"
224
+
225
+ finfo = torch.finfo(dtype)
226
+ fp8_max = finfo.max
227
+
228
+ fp8_min = -fp8_max
229
+
230
+ x_q = torch.empty_like(x, device=x.device, dtype=dtype)
231
+ M = x.numel() // group_size
232
+ N = group_size
233
+ x_s = torch.empty(
234
+ x.shape[:-1] + (x.shape[-1] // group_size,),
235
+ device=x.device,
236
+ dtype=torch.float32,
133
237
  )
134
238
 
239
+ sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
240
+
135
241
  return x_q, x_s
136
242
 
137
243
 
@@ -10,6 +10,9 @@ from sglang.srt.layers.quantization.fp8_kernel import (
10
10
  from sglang.srt.utils import is_hip
11
11
 
12
12
  is_hip_ = is_hip()
13
+ _is_cuda = torch.cuda.is_available() and torch.version.cuda
14
+ if _is_cuda:
15
+ from sgl_kernel import fp8_blockwise_scaled_mm
13
16
 
14
17
 
15
18
  def normalize_e4m3fn_to_e4m3fnuz(
@@ -36,6 +39,19 @@ def normalize_e4m3fn_to_e4m3fnuz(
36
39
  return weight, weight_scale, input_scale
37
40
 
38
41
 
42
+ def cutlass_block_fp8_supported() -> bool:
43
+ if _is_cuda:
44
+ major, minor = torch.cuda.get_device_capability()
45
+ sm_version = major * 10 + minor
46
+ cuda_version = tuple(map(int, torch.version.cuda.split(".")))
47
+ if cuda_version >= (12, 0) and sm_version >= 90:
48
+ return True
49
+ return False
50
+
51
+
52
+ CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()
53
+
54
+
39
55
  def apply_w8a8_block_fp8_linear(
40
56
  input: torch.Tensor,
41
57
  weight: torch.Tensor,
@@ -48,11 +64,24 @@ def apply_w8a8_block_fp8_linear(
48
64
  # View input as 2D matrix for fp8 methods
49
65
  input_2d = input.view(-1, input.shape[-1])
50
66
  output_shape = [*input.shape[:-1], weight.shape[0]]
51
-
52
- q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1])
53
- output = w8a8_block_fp8_matmul(
54
- q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
67
+ # TODO: add more robust shape check here
68
+ shape_supported_by_cutlass = (
69
+ weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
55
70
  )
71
+ if CUTLASS_BLOCK_FP8_SUPPORTED and shape_supported_by_cutlass:
72
+ q_input, x_scale = per_token_group_quant_fp8(
73
+ input_2d, block_size[1], column_major_scales=True
74
+ )
75
+ output = fp8_blockwise_scaled_mm(
76
+ q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
77
+ )
78
+ else:
79
+ q_input, x_scale = per_token_group_quant_fp8(
80
+ input_2d, block_size[1], column_major_scales=False
81
+ )
82
+ output = w8a8_block_fp8_matmul(
83
+ q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
84
+ )
56
85
 
57
86
  if bias is not None:
58
87
  output = output + bias
@@ -210,6 +210,7 @@ class DetokenizerManager:
210
210
  input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
211
211
  output_top_logprobs_val=recv_obj.output_top_logprobs_val,
212
212
  output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
213
+ output_hidden_states=recv_obj.output_hidden_states,
213
214
  )
214
215
  )
215
216
 
@@ -371,6 +371,8 @@ class BatchTokenIDOut:
371
371
  output_top_logprobs_val: List[List]
372
372
  output_top_logprobs_idx: List[List]
373
373
 
374
+ output_hidden_states: List[List[float]]
375
+
374
376
 
375
377
  @dataclass
376
378
  class BatchStrOut:
@@ -397,6 +399,8 @@ class BatchStrOut:
397
399
  output_top_logprobs_val: List[List]
398
400
  output_top_logprobs_idx: List[List]
399
401
 
402
+ output_hidden_states: List[List[float]]
403
+
400
404
 
401
405
  @dataclass
402
406
  class BatchEmbeddingOut:
@@ -65,6 +65,7 @@ global_server_args_dict = {
65
65
  "enable_dp_attention": ServerArgs.enable_dp_attention,
66
66
  "enable_ep_moe": ServerArgs.enable_ep_moe,
67
67
  "device": ServerArgs.device,
68
+ "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
68
69
  }
69
70
 
70
71
  logger = logging.getLogger(__name__)
@@ -315,6 +316,7 @@ class Req:
315
316
  self.output_token_logprobs_val = self.output_token_logprobs_idx = (
316
317
  self.output_top_logprobs_val
317
318
  ) = self.output_top_logprobs_idx = None
319
+ self.hidden_states = []
318
320
 
319
321
  # Logprobs (internal values)
320
322
  # The tokens is prefilled but need to be considered as decode tokens
@@ -604,6 +606,9 @@ class ScheduleBatch:
604
606
  # Enable custom logit processor
605
607
  enable_custom_logit_processor: bool = False
606
608
 
609
+ # Return hidden states
610
+ return_hidden_states: bool = False
611
+
607
612
  @classmethod
608
613
  def init_new(
609
614
  cls,
@@ -615,6 +620,7 @@ class ScheduleBatch:
615
620
  enable_overlap: bool,
616
621
  spec_algorithm: SpeculativeAlgorithm,
617
622
  enable_custom_logit_processor: bool,
623
+ return_hidden_states: bool = False,
618
624
  ):
619
625
  return cls(
620
626
  reqs=reqs,
@@ -629,6 +635,7 @@ class ScheduleBatch:
629
635
  device=req_to_token_pool.device,
630
636
  spec_algorithm=spec_algorithm,
631
637
  enable_custom_logit_processor=enable_custom_logit_processor,
638
+ return_hidden_states=return_hidden_states,
632
639
  )
633
640
 
634
641
  def batch_size(self):
@@ -1196,9 +1203,15 @@ class ScheduleBatch:
1196
1203
  spec_algorithm=self.spec_algorithm,
1197
1204
  spec_info=self.spec_info,
1198
1205
  capture_hidden_mode=(
1199
- getattr(self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)
1200
- if self.spec_info
1201
- else CaptureHiddenMode.NULL
1206
+ CaptureHiddenMode.FULL
1207
+ if self.return_hidden_states
1208
+ else (
1209
+ getattr(
1210
+ self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
1211
+ )
1212
+ if self.spec_info
1213
+ else CaptureHiddenMode.NULL
1214
+ )
1202
1215
  ),
1203
1216
  )
1204
1217
 
@@ -997,6 +997,7 @@ class Scheduler:
997
997
  self.enable_overlap,
998
998
  self.spec_algorithm,
999
999
  self.server_args.enable_custom_logit_processor,
1000
+ self.server_args.return_hidden_states,
1000
1001
  )
1001
1002
  new_batch.prepare_for_extend()
1002
1003
 
@@ -1156,6 +1157,8 @@ class Scheduler:
1156
1157
  logits_output.input_token_logprobs.tolist()
1157
1158
  )
1158
1159
 
1160
+ hidden_state_offset = 0
1161
+
1159
1162
  # Check finish conditions
1160
1163
  logprob_pt = 0
1161
1164
  for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
@@ -1182,6 +1185,21 @@ class Scheduler:
1182
1185
  i, req, logprob_pt, next_token_ids, logits_output
1183
1186
  )
1184
1187
 
1188
+ if (
1189
+ self.server_args.return_hidden_states
1190
+ and logits_output.hidden_states is not None
1191
+ ):
1192
+ req.hidden_states.append(
1193
+ logits_output.hidden_states[
1194
+ hidden_state_offset : (
1195
+ hidden_state_offset := hidden_state_offset
1196
+ + len(req.origin_input_ids)
1197
+ )
1198
+ ]
1199
+ .cpu()
1200
+ .clone()
1201
+ )
1202
+
1185
1203
  if req.grammar is not None:
1186
1204
  req.grammar.accept_token(next_token_id)
1187
1205
  req.grammar.finished = req.finished()
@@ -1275,6 +1293,12 @@ class Scheduler:
1275
1293
  logits_output.next_token_top_logprobs_idx[i]
1276
1294
  )
1277
1295
 
1296
+ if (
1297
+ self.server_args.return_hidden_states
1298
+ and logits_output.hidden_states is not None
1299
+ ):
1300
+ req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
1301
+
1278
1302
  if req.grammar is not None:
1279
1303
  req.grammar.accept_token(next_token_id)
1280
1304
  req.grammar.finished = req.finished()
@@ -1398,6 +1422,7 @@ class Scheduler:
1398
1422
  completion_tokens = []
1399
1423
  cached_tokens = []
1400
1424
  spec_verify_ct = []
1425
+ hidden_states = []
1401
1426
 
1402
1427
  if return_logprob:
1403
1428
  input_token_logprobs_val = []
@@ -1464,6 +1489,8 @@ class Scheduler:
1464
1489
  output_top_logprobs_val.append(req.output_top_logprobs_val)
1465
1490
  output_top_logprobs_idx.append(req.output_top_logprobs_idx)
1466
1491
 
1492
+ hidden_states.append(req.hidden_states)
1493
+
1467
1494
  # Send to detokenizer
1468
1495
  if rids:
1469
1496
  self.send_to_detokenizer.send_pyobj(
@@ -1490,6 +1517,7 @@ class Scheduler:
1490
1517
  input_top_logprobs_idx,
1491
1518
  output_top_logprobs_val,
1492
1519
  output_top_logprobs_idx,
1520
+ hidden_states,
1493
1521
  )
1494
1522
  )
1495
1523
  else: # embedding or reward model
@@ -1553,6 +1581,7 @@ class Scheduler:
1553
1581
  self.enable_overlap,
1554
1582
  self.spec_algorithm,
1555
1583
  self.server_args.enable_custom_logit_processor,
1584
+ self.server_args.return_hidden_states,
1556
1585
  )
1557
1586
  idle_batch.prepare_for_idle()
1558
1587
  return idle_batch
@@ -796,6 +796,12 @@ class TokenizerManager:
796
796
  }
797
797
  )
798
798
 
799
+ if (
800
+ hasattr(recv_obj, "output_hidden_states")
801
+ and len(recv_obj.output_hidden_states[i]) > 0
802
+ ):
803
+ meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
804
+
799
805
  if isinstance(recv_obj, BatchStrOut):
800
806
  out_dict = {
801
807
  "text": recv_obj.output_strs[i],
@@ -156,6 +156,10 @@ class TpModelWorkerClient:
156
156
  logits_output.input_token_logprobs = (
157
157
  logits_output.input_token_logprobs.to("cpu", non_blocking=True)
158
158
  )
159
+ if logits_output.hidden_states is not None:
160
+ logits_output.hidden_states = logits_output.hidden_states.to(
161
+ "cpu", non_blocking=True
162
+ )
159
163
  next_token_ids = next_token_ids.to("cpu", non_blocking=True)
160
164
  copy_done.record()
161
165
 
@@ -33,6 +33,9 @@ from sglang.srt.model_executor.forward_batch_info import (
33
33
  ForwardBatch,
34
34
  ForwardMode,
35
35
  )
36
+ from sglang.srt.utils import is_hip
37
+
38
+ is_hip_ = is_hip()
36
39
 
37
40
  if TYPE_CHECKING:
38
41
  from sglang.srt.model_executor.model_runner import ModelRunner
@@ -129,6 +132,8 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
129
132
  if bs <= model_runner.req_to_token_pool.size
130
133
  and bs <= server_args.cuda_graph_max_bs
131
134
  ]
135
+ if is_hip_:
136
+ capture_bs += [i * 8 for i in range(21, 33)]
132
137
  compile_bs = (
133
138
  [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
134
139
  if server_args.enable_torch_compile
@@ -349,7 +354,13 @@ class CudaGraphRunner:
349
354
  spec_algorithm=self.model_runner.spec_algorithm,
350
355
  spec_info=spec_info,
351
356
  capture_hidden_mode=(
352
- spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
357
+ CaptureHiddenMode.FULL
358
+ if self.model_runner.server_args.return_hidden_states
359
+ else (
360
+ spec_info.capture_hidden_mode
361
+ if spec_info
362
+ else CaptureHiddenMode.NULL
363
+ )
353
364
  ),
354
365
  )
355
366