sglang 0.4.0__py3-none-any.whl → 0.4.0.post2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (72)
  1. sglang/__init__.py +1 -1
  2. sglang/bench_offline_throughput.py +18 -6
  3. sglang/bench_one_batch.py +13 -0
  4. sglang/bench_serving.py +8 -1
  5. sglang/check_env.py +140 -48
  6. sglang/lang/backend/runtime_endpoint.py +1 -0
  7. sglang/lang/chat_template.py +32 -0
  8. sglang/llama3_eval.py +316 -0
  9. sglang/srt/constrained/outlines_backend.py +5 -0
  10. sglang/srt/constrained/xgrammar_backend.py +9 -6
  11. sglang/srt/layers/attention/__init__.py +5 -2
  12. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  13. sglang/srt/layers/attention/flashinfer_backend.py +22 -5
  14. sglang/srt/layers/attention/torch_native_backend.py +22 -8
  15. sglang/srt/layers/attention/triton_backend.py +38 -33
  16. sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
  17. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  18. sglang/srt/layers/ep_moe/__init__.py +0 -0
  19. sglang/srt/layers/ep_moe/kernels.py +349 -0
  20. sglang/srt/layers/ep_moe/layer.py +665 -0
  21. sglang/srt/layers/fused_moe_triton/fused_moe.py +64 -21
  22. sglang/srt/layers/fused_moe_triton/layer.py +1 -1
  23. sglang/srt/layers/logits_processor.py +133 -95
  24. sglang/srt/layers/quantization/__init__.py +2 -47
  25. sglang/srt/layers/quantization/fp8.py +607 -0
  26. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  27. sglang/srt/layers/radix_attention.py +11 -2
  28. sglang/srt/layers/sampler.py +29 -5
  29. sglang/srt/layers/torchao_utils.py +58 -45
  30. sglang/srt/managers/detokenizer_manager.py +37 -17
  31. sglang/srt/managers/io_struct.py +39 -10
  32. sglang/srt/managers/schedule_batch.py +39 -24
  33. sglang/srt/managers/schedule_policy.py +64 -5
  34. sglang/srt/managers/scheduler.py +236 -197
  35. sglang/srt/managers/tokenizer_manager.py +99 -58
  36. sglang/srt/managers/tp_worker_overlap_thread.py +7 -5
  37. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  38. sglang/srt/mem_cache/chunk_cache.py +2 -2
  39. sglang/srt/mem_cache/memory_pool.py +5 -1
  40. sglang/srt/mem_cache/radix_cache.py +12 -2
  41. sglang/srt/model_executor/cuda_graph_runner.py +39 -11
  42. sglang/srt/model_executor/model_runner.py +24 -9
  43. sglang/srt/model_parallel.py +67 -10
  44. sglang/srt/models/commandr.py +2 -2
  45. sglang/srt/models/deepseek_v2.py +87 -7
  46. sglang/srt/models/gemma2.py +34 -0
  47. sglang/srt/models/gemma2_reward.py +0 -1
  48. sglang/srt/models/granite.py +517 -0
  49. sglang/srt/models/grok.py +72 -13
  50. sglang/srt/models/llama.py +22 -5
  51. sglang/srt/models/llama_classification.py +11 -23
  52. sglang/srt/models/llama_reward.py +0 -2
  53. sglang/srt/models/llava.py +37 -14
  54. sglang/srt/models/mixtral.py +12 -9
  55. sglang/srt/models/phi3_small.py +0 -5
  56. sglang/srt/models/qwen2.py +20 -0
  57. sglang/srt/models/qwen2_moe.py +0 -5
  58. sglang/srt/models/torch_native_llama.py +0 -5
  59. sglang/srt/openai_api/adapter.py +4 -0
  60. sglang/srt/openai_api/protocol.py +9 -4
  61. sglang/srt/sampling/sampling_batch_info.py +9 -8
  62. sglang/srt/server.py +4 -4
  63. sglang/srt/server_args.py +62 -13
  64. sglang/srt/utils.py +57 -10
  65. sglang/test/test_utils.py +3 -2
  66. sglang/utils.py +10 -3
  67. sglang/version.py +1 -1
  68. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/METADATA +15 -9
  69. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/RECORD +72 -65
  70. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/LICENSE +0 -0
  71. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/WHEEL +0 -0
  72. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/top_level.txt +0 -0
@@ -284,6 +284,9 @@ def extend_attention_fwd(
     elif Lq == 288:
         BLOCK_DMODEL = 256
         BLOCK_DPE = 32
+    elif Lq == 192:
+        BLOCK_DMODEL = 128
+        BLOCK_DPE = 64
     else:
         BLOCK_DMODEL = triton.next_power_of_2(Lq)
         BLOCK_DPE = 0
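
(The new Lq == 192 branch plausibly targets models whose query heads are 192 wide, split into 128 regular dimensions plus 64 rotary-position dimensions, hence BLOCK_DMODEL = 128 and BLOCK_DPE = 64; the deepseek_v2.py changes elsewhere in this release are the likely consumer.)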
sglang/srt/layers/ep_moe/__init__.py: file without changes (new, empty file)
sglang/srt/layers/ep_moe/kernels.py (new file)
@@ -0,0 +1,349 @@
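# Triton kernels for expert-parallel MoE (EP-MoE): tokens are sorted by expert,
# scattered into contiguous per-expert segments, pushed through grouped GEMMs
# with a fused SiLU-and-mul in between, then gathered back into source order.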
import logging
from typing import Optional

import torch
import triton
import triton.language as tl

logger = logging.getLogger(__name__)


# One program per expert: binary-search the expert-sorted topk ids for the last
# slot owned by this expert, so that seg_indptr[e + 1] marks the end of expert
# e's segment.
@triton.jit
def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
    expert = tl.program_id(0)
    low = 0
    high = num_toks - 1
    target_location = -1
    while low <= high:
        mid = (low + high) // 2

        if tl.load(reorder_topk_ids + mid) > expert:
            high = mid - 1
        else:
            low = mid + 1
            target_location = mid
    tl.store(seg_indptr + expert + 1, target_location + 1)


# Invert the sort permutation in blocks: src2dst[src_slot] = dst_slot.
@triton.jit
def compute_src2dst_triton_kernel(
    reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr
):
    pid = tl.program_id(axis=0)
    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = dst_id < num_toks
    src_id = tl.load(reorder_ids + dst_id, mask=mask)
    tl.store(src2dst + src_id, dst_id, mask=mask)


# Sort the flattened (token, topk) assignments by expert id, then build the
# per-expert segment pointers and the source-to-destination slot mapping.
def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
    seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)

    compute_seg_indptr_triton_kernel[(num_experts,)](
        reorder_topk_ids, seg_indptr, topk_ids.numel()
    )

    BLOCK_SIZE = 512
    grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),)
    compute_src2dst_triton_kernel[grid](
        reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE
    )
    return reorder_topk_ids, src2dst, seg_indptr


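# For orientation, a worked toy example (hypothetical values; the kernels above
# assume a CUDA device):
#
#   topk_ids = torch.tensor([[0, 2], [2, 3], [0, 0]], device="cuda")
#   reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, num_experts=4)
#   # reorder_topk_ids -> [0, 0, 0, 2, 2, 3]  flattened ids, sorted by expert
#   # seg_indptr       -> [0, 3, 3, 5, 6]     expert e owns sorted slots seg_indptr[e]:seg_indptr[e+1]
#   # src2dst          -> [0, 3, 4, 5, 1, 2]  flattened slot i lands at sorted slot src2dst[i]

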
# Scatter each token's hidden state into the expert-sorted gateup buffer, once
# per selected expert owned by this rank, applying the per-expert fp8 input
# scale when one is given.
@triton.jit
def pre_reorder_triton_kernel(
    input_ptr,
    gateup_input_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    a1_scales_ptr,
    start_expert_id,
    end_expert_id,
    topk,
    hidden_size,
    BLOCK_SIZE: tl.constexpr,
):
    OutDtype = gateup_input_ptr.dtype.element_ty

    src_idx = tl.program_id(0)
    src2dst_ptr = src2dst_ptr + src_idx * topk
    topk_ids_ptr = topk_ids_ptr + src_idx * topk

    src_ptr = input_ptr + src_idx * hidden_size
    for idx in range(topk):
        expert_id = tl.load(topk_ids_ptr + idx)
        if expert_id >= start_expert_id and expert_id <= end_expert_id:
            if a1_scales_ptr is not None:
                scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id)
            else:
                scale = 1.0

            dst_idx = tl.load(src2dst_ptr + idx)
            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
                offset = start_offset + tl.arange(0, BLOCK_SIZE)
                mask = offset < hidden_size
                in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
                out_data = (in_data * scale).to(OutDtype)
                tl.store(dst_ptr + offset, out_data, mask=mask)


# Fused SiLU-and-mul over the gate/up halves of one expert-sorted row, with an
# optional per-expert scale folded in for fp8 quantization of the down input.
@triton.jit
def silu_and_mul_triton_kernel(
    gateup_output,
    down_input,
    hidden_size,
    reorder_topk_ids,
    scales,
    start_expert_id,
    end_expert_id,
    BLOCK_SIZE: tl.constexpr,
):
    InDtype = gateup_output.dtype.element_ty
    OutDtype = down_input.dtype.element_ty

    half_hidden_size = hidden_size // 2

    pid = tl.program_id(0)
    expert_id = tl.load(reorder_topk_ids + pid)
    if expert_id >= start_expert_id and expert_id <= end_expert_id:
        gateup_output_ptr = gateup_output + pid * hidden_size
        gate_output_ptr = gateup_output_ptr
        up_output_ptr = gateup_output_ptr + half_hidden_size
        down_input_ptr = down_input + pid * half_hidden_size

        if scales is not None:
            scale = tl.load(scales + expert_id - start_expert_id)
            scale = (1 / scale).to(InDtype)
        else:
            scale = 1

        for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
            offset = start_offset + tl.arange(0, BLOCK_SIZE)
            mask = offset < half_hidden_size

            gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
            up_output = tl.load(up_output_ptr + offset, mask=mask)

            # silu & mul & quantize
            gate_output = gate_output * tl.sigmoid(gate_output)
            gate_output = gate_output.to(InDtype)

            silu_mul_output = gate_output * up_output * scale
            silu_mul_output = silu_mul_output.to(OutDtype)
            tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)


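# Reference semantics, up to dtype casts (H = hidden_size // 2, row t of the
# grouped GEMM output):
#   down_input[t, :] = silu(gateup_output[t, :H]) * gateup_output[t, H:] / scales[expert]
# with the division dropped when scales is None.

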
# Gather the expert outputs belonging to each source token and combine the
# copies using their router weights.
@triton.jit
def post_reorder_triton_kernel(
    down_output_ptr,
    output_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    topk_weights_ptr,
    start_expert_id,
    end_expert_id,
    topk,
    hidden_size,
    BLOCK_SIZE: tl.constexpr,
):
    InDtype = down_output_ptr.dtype.element_ty

    src_idx = tl.program_id(0)
    src2dst_ptr = src2dst_ptr + src_idx * topk
    topk_ids_ptr = topk_ids_ptr + src_idx * topk
    topk_weights_ptr = topk_weights_ptr + src_idx * topk

    computed = False
    store_ptr = output_ptr + src_idx * hidden_size
    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
        offset = start_offset + tl.arange(0, BLOCK_SIZE)
        mask = offset < hidden_size

        sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
        for idx in range(topk):
            expert_id = tl.load(topk_ids_ptr + idx)
            if expert_id >= start_expert_id and expert_id <= end_expert_id:
                computed = True
                dst_idx = tl.load(src2dst_ptr + idx)
                weight_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
                load_ptr = down_output_ptr + dst_idx * hidden_size
                in_data = tl.load(load_ptr + offset, mask=mask)
                sum_vec += in_data * weight_scale
        tl.store(store_ptr + offset, sum_vec, mask=mask)

    if computed == False:
        for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
            offset = start_offset + tl.arange(0, BLOCK_SIZE)
            mask = offset < hidden_size
            tl.store(
                store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask
            )


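# Tokens whose selected experts all live on other EP ranks are zero-filled
# rather than skipped, presumably so the partial outputs can simply be summed
# across ranks.

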
# Map a flat M-tile index to its row range within the segmented M dimension and
# to the expert whose weights that tile should use.
@triton.jit
def compute_m_range(
    pid,
    batch_size,
    seg_indptr,
    weight_indices,
    m_num_tiles_indptr,
    BLOCK_SIZE_M: tl.constexpr,
):
    idx = 0
    for bs in range(batch_size):
        tiles = tl.load(m_num_tiles_indptr + bs)
        if pid >= tiles:
            idx = bs

    idx_start = tl.load(m_num_tiles_indptr + idx)

    m_range_start = tl.load(seg_indptr + idx) + (pid - idx_start) * BLOCK_SIZE_M
    m_range_end = min(tl.load(seg_indptr + idx + 1), m_range_start + BLOCK_SIZE_M)
    expert_id = tl.load(weight_indices + idx)
    return m_range_start, m_range_end, expert_id


# Segmented grouped GEMM: every M-tile stays inside one expert's row segment
# and multiplies it with that expert's (column-major) weight slice, so that
# c[m] = a[m] @ b[expert].T, with optional per-expert fp8 rescaling.
@triton.jit
def grouped_gemm_triton_kernel(
    a,
    b,
    c,
    batch_size,
    N,
    K,
    seg_indptr,
    weight_indices,
    m_num_tiles_indptr,
    use_fp8_w8a8,
    scale_a,
    scale_b,
    a_stride_0: tl.constexpr,
    b_stride_0: tl.constexpr,
    b_stride_1: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    c_dtype = c.dtype.element_ty

    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    total_m_block = tl.load(m_num_tiles_indptr + batch_size)
    if pid_m >= total_m_block:
        return

    m_range_start, m_range_end, expert_id = compute_m_range(
        pid_m, batch_size, seg_indptr, weight_indices, m_num_tiles_indptr, BLOCK_SIZE_M
    )
    if m_range_end - m_range_start == 0:
        return

    n_range_start = pid_n * BLOCK_SIZE_N
    n_range_end = min(n_range_start + BLOCK_SIZE_N, N)

    offs_am = tl.arange(0, BLOCK_SIZE_M)
    offs_bn = tl.arange(0, BLOCK_SIZE_N)

    offs_am = tl.where(offs_am < m_range_end - m_range_start, offs_am, 0)
    offs_bn = tl.where(offs_bn < n_range_end - n_range_start, offs_bn, 0)
    offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
    offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptr = a + (m_range_start + offs_am[:, None]) * a_stride_0 + offs_k[None, :]
    b_ptr = b + (
        (expert_id * b_stride_0)
        + (n_range_start + offs_bn[:, None]) * b_stride_1
        + offs_k[None, :]
    )
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a_tile = tl.load(
            a_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
        )
        b_tile = tl.load(
            b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
        )
        accumulator = tl.dot(a_tile, b_tile.T, accumulator)
        a_ptr += BLOCK_SIZE_K
        b_ptr += BLOCK_SIZE_K

    if use_fp8_w8a8:
        scale_a_value = tl.load(scale_a + expert_id)
        scale_b_value = tl.load(scale_b + expert_id)
        accumulator *= scale_a_value * scale_b_value
    c_tile = accumulator.to(c_dtype)

    offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = n_range_start + tl.arange(0, BLOCK_SIZE_N)
    c_ptr = c + offs_cm[:, None] * N + offs_cn[None, :]
    c_mask = (offs_cm[:, None] < m_range_end) & (offs_cn[None, :] < n_range_end)
    tl.store(c_ptr, c_tile, mask=c_mask)


# Single-program prefix sum turning per-segment row counts into tile offsets.
@triton.jit
def compute_m_num_tiles_indptr(
    m_num_tiles_indptr, seg_indptr, batch_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr
):
    for bs in range(batch_size):
        m = tl.load(seg_indptr + bs + 1) - tl.load(seg_indptr + bs)
        cur_num_tiles = tl.cdiv(m, BLOCK_SIZE_M)
        pre_num_tiles = tl.load(m_num_tiles_indptr + bs)
        tl.store(m_num_tiles_indptr + bs + 1, pre_num_tiles + cur_num_tiles)


# Host-side wrapper: compute the tile layout for the segmented M dimension and
# launch the grouped GEMM kernel.
def grouped_gemm_triton(
    a: torch.Tensor,
    b: torch.Tensor,
    c: torch.Tensor,
    batch_size: int,
    weight_column_major: bool,
    seg_indptr: Optional[torch.Tensor] = None,
    weight_indices: Optional[torch.Tensor] = None,
    use_fp8_w8a8: bool = False,
    scale_a: Optional[torch.Tensor] = None,
    scale_b: Optional[torch.Tensor] = None,
):
    assert weight_column_major == True  # TODO: support more weight layouts
    if use_fp8_w8a8:
        assert scale_a is not None and scale_b is not None

    config = {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
    }

    m_num_tiles_indptr = torch.zeros(batch_size + 1, device=a.device, dtype=torch.int64)
    compute_m_num_tiles_indptr[(1,)](
        m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"]
    )

    grid = lambda META: (
        triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size,
        triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
    )

    grouped_gemm_triton_kernel[grid](
        a,
        b,
        c,
        batch_size,
        b.size(1),
        b.size(2),
        seg_indptr,
        weight_indices,
        m_num_tiles_indptr,
        use_fp8_w8a8,
        scale_a,
        scale_b,
        a.stride(0),
        b.stride(0),
        b.stride(1),
        **config,
    )
    return c
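
Taken together, a minimal smoke test of the grouped GEMM path could look like the sketch below. This is a hypothetical example, not part of the release: it assumes a CUDA device with Triton installed, and that the function is importable from sglang.srt.layers.ep_moe.kernels as shipped in this wheel.

import torch
from sglang.srt.layers.ep_moe.kernels import grouped_gemm_triton

# Two experts, weights stored column-major as (num_experts, N, K).
E, N, K, M = 2, 32, 64, 6
a = torch.randn(M, K, device="cuda", dtype=torch.float16)  # expert-sorted activations
b = torch.randn(E, N, K, device="cuda", dtype=torch.float16)
c = torch.empty(M, N, device="cuda", dtype=torch.float16)
seg_indptr = torch.tensor([0, 4, M], device="cuda", dtype=torch.int64)  # rows 0-3 -> segment 0, rows 4-5 -> segment 1
weight_indices = torch.tensor([0, 1], device="cuda", dtype=torch.int64)
grouped_gemm_triton(a, b, c, batch_size=E, weight_column_major=True,
                    seg_indptr=seg_indptr, weight_indices=weight_indices)
# c[:4] should approximate a[:4] @ b[0].T, and c[4:] should approximate
# a[4:] @ b[1].T, within fp16 tolerance.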