sglang 0.4.2__py3-none-any.whl → 0.4.2.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. sglang/srt/constrained/outlines_backend.py +9 -1
  2. sglang/srt/custom_op.py +40 -0
  3. sglang/srt/entrypoints/engine.py +2 -2
  4. sglang/srt/layers/activation.py +10 -5
  5. sglang/srt/layers/attention/flashinfer_backend.py +284 -39
  6. sglang/srt/layers/attention/triton_backend.py +71 -7
  7. sglang/srt/layers/attention/triton_ops/decode_attention.py +53 -59
  8. sglang/srt/layers/attention/triton_ops/prefill_attention.py +6 -0
  9. sglang/srt/layers/attention/vision.py +243 -40
  10. sglang/srt/layers/layernorm.py +1 -5
  11. sglang/srt/layers/moe/ep_moe/layer.py +1 -3
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  13. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +200 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +200 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +200 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +178 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +200 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +175 -0
  20. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -11
  21. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -3
  22. sglang/srt/layers/moe/topk.py +4 -0
  23. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  24. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  25. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  26. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  30. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  32. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  33. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  34. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  35. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  36. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  39. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  40. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  42. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  44. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/fp8.py +7 -0
  46. sglang/srt/layers/quantization/fp8_kernel.py +140 -2
  47. sglang/srt/layers/rotary_embedding.py +29 -15
  48. sglang/srt/layers/sampler.py +9 -6
  49. sglang/srt/lora/backend/__init__.py +8 -0
  50. sglang/srt/lora/backend/base_backend.py +95 -0
  51. sglang/srt/lora/backend/flashinfer_backend.py +91 -0
  52. sglang/srt/lora/backend/triton_backend.py +61 -0
  53. sglang/srt/lora/lora.py +127 -112
  54. sglang/srt/lora/lora_manager.py +50 -18
  55. sglang/srt/lora/triton_ops/__init__.py +5 -0
  56. sglang/srt/lora/triton_ops/qkv_lora_b.py +182 -0
  57. sglang/srt/lora/triton_ops/sgemm_lora_a.py +143 -0
  58. sglang/srt/lora/triton_ops/sgemm_lora_b.py +159 -0
  59. sglang/srt/managers/image_processor.py +77 -38
  60. sglang/srt/managers/scheduler.py +17 -3
  61. sglang/srt/mem_cache/base_prefix_cache.py +4 -0
  62. sglang/srt/mem_cache/chunk_cache.py +3 -0
  63. sglang/srt/mem_cache/radix_cache.py +30 -1
  64. sglang/srt/model_executor/cuda_graph_runner.py +77 -80
  65. sglang/srt/model_executor/forward_batch_info.py +58 -59
  66. sglang/srt/model_executor/model_runner.py +2 -2
  67. sglang/srt/models/minicpmv.py +129 -76
  68. sglang/srt/models/mllama.py +16 -56
  69. sglang/srt/models/qwen2.py +4 -1
  70. sglang/srt/models/qwen2_vl.py +19 -9
  71. sglang/srt/server_args.py +19 -2
  72. sglang/srt/speculative/build_eagle_tree.py +4 -2
  73. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +213 -0
  74. sglang/srt/speculative/eagle_utils.py +361 -372
  75. sglang/srt/speculative/eagle_worker.py +177 -45
  76. sglang/srt/utils.py +7 -2
  77. sglang/test/runners.py +2 -0
  78. sglang/utils.py +42 -0
  79. sglang/version.py +1 -1
  80. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/METADATA +16 -7
  81. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/RECORD +84 -45
  82. sglang/srt/layers/custom_op_util.py +0 -25
  83. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/LICENSE +0 -0
  84. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/WHEEL +0 -0
  85. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
+ {
+     "1": {
+         "BLOCK_SIZE_M": 32,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 8,
+         "num_stages": 3
+     },
+     "2": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 8,
+         "num_stages": 5
+     },
+     "4": {
+         "BLOCK_SIZE_M": 32,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 8,
+         "num_stages": 2
+     },
+     "8": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 5
+     },
+     "16": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "24": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "32": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 8,
+         "num_stages": 5
+     },
+     "48": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 8,
+         "num_stages": 3
+     },
+     "64": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "96": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 8,
+         "num_stages": 5
+     },
+     "128": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 8,
+         "num_stages": 3
+     },
+     "256": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 8,
+         "num_stages": 3
+     },
+     "512": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "1024": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "1536": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "2048": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "3072": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "4096": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 2
+     }
+ }
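
A note on how these tuned JSON files are consumed: the top-level keys are candidate batch sizes M, and each value is a set of Triton tiling parameters. Below is a minimal lookup sketch assuming nearest-key selection (this is not sglang's exact lookup code, and the filename is only an example of the shape shared by the config files listed above):

import json

def pick_block_config(configs: dict, m: int) -> dict:
    # Pick the tuning entry whose batch-size key is closest to the runtime M.
    best_key = min(configs, key=lambda k: abs(int(k) - m))
    return configs[best_key]

with open("E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json") as f:
    tuned = json.load(f)

print(pick_block_config(tuned, m=200))  # for this file, lands on the "256" entry
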
@@ -290,6 +290,13 @@ class Fp8LinearMethod(LinearMethodBase):
                      weight_scale, requires_grad=False
                  )
                  layer.input_scale = None
+             else:
+                 layer.weight = torch.nn.Parameter(
+                     layer.weight.data, requires_grad=False
+                 )
+                 layer.weight_scale_inv = torch.nn.Parameter(
+                     layer.weight_scale_inv.data, requires_grad=False
+                 )
              return
          layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
          # If checkpoint not serialized fp8, quantize the weights.
@@ -22,7 +22,7 @@ import torch
  import triton
  import triton.language as tl

- from sglang.srt.utils import get_device_name, is_hip
+ from sglang.srt.utils import get_device_core_count, get_device_name, is_hip

  is_hip_ = is_hip()
  fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
@@ -220,6 +220,132 @@ def _w8a8_block_fp8_matmul(
      tl.store(c_ptrs, c, mask=c_mask)


+ @triton.jit
+ def _w8a8_block_fp8_matmul_unrolledx4(
+     # Pointers to inputs and output
+     A,
+     B,
+     C,
+     As,
+     Bs,
+     # Shape for matmul
+     M,
+     N,
+     K,
+     # Block size for block-wise quantization
+     group_n,
+     group_k,
+     # Stride for inputs and output
+     stride_am,
+     stride_ak,
+     stride_bk,
+     stride_bn,
+     stride_cm,
+     stride_cn,
+     stride_As_m,
+     stride_As_k,
+     stride_Bs_k,
+     stride_Bs_n,
+     # Meta-parameters
+     BLOCK_SIZE_M: tl.constexpr,
+     BLOCK_SIZE_N: tl.constexpr,
+     BLOCK_SIZE_K: tl.constexpr,
+     GROUP_SIZE_M: tl.constexpr,
+ ):
+     """Triton-accelerated function used to perform linear operations (dot
+     product) on input tensors `A` and `B` with block-wise quantization, and store the result in output
+     tensor `C`.
+     """
+
+     pid = tl.program_id(axis=0)
+     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+     num_pid_in_group = GROUP_SIZE_M * num_pid_n
+     group_id = pid // num_pid_in_group
+     first_pid_m = group_id * GROUP_SIZE_M
+     group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+     pid_m = first_pid_m + (pid % group_size_m)
+     pid_n = (pid % num_pid_in_group) // group_size_m
+
+     offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+     offs_k = tl.arange(0, BLOCK_SIZE_K)
+     a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+     b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+     As_ptrs = As + offs_am * stride_As_m
+     offs_bsn = offs_bn // group_n
+     Bs_ptrs = Bs + offs_bsn * stride_Bs_n
+
+     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+     # manually unroll to 4 iterations
+     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K) // 4):
+         # 1st iteration
+         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+         k_start = k * BLOCK_SIZE_K
+         offs_ks = k_start // group_k
+         a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
+         b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
+
+         accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+         a_ptrs += BLOCK_SIZE_K * stride_ak
+         b_ptrs += BLOCK_SIZE_K * stride_bk
+
+         # 2nd iteration
+         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+         k_start = k_start + BLOCK_SIZE_K
+         offs_ks = k_start // group_k
+         a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
+         b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
+
+         accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+         a_ptrs += BLOCK_SIZE_K * stride_ak
+         b_ptrs += BLOCK_SIZE_K * stride_bk
+
+         # 3rd iteration
+         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+         k_start = k_start + BLOCK_SIZE_K
+         offs_ks = k_start // group_k
+         a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
+         b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
+
+         accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+         a_ptrs += BLOCK_SIZE_K * stride_ak
+         b_ptrs += BLOCK_SIZE_K * stride_bk
+
+         # 4th iteration
+         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+         k_start = k_start + BLOCK_SIZE_K
+         offs_ks = k_start // group_k
+         a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
+         b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
+
+         accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+         a_ptrs += BLOCK_SIZE_K * stride_ak
+         b_ptrs += BLOCK_SIZE_K * stride_bk
+
+     if C.dtype.element_ty == tl.bfloat16:
+         c = accumulator.to(tl.bfloat16)
+     elif C.dtype.element_ty == tl.float16:
+         c = accumulator.to(tl.float16)
+     else:
+         c = accumulator.to(tl.float32)
+
+     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+     c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+     tl.store(c_ptrs, c, mask=c_mask)
+
+
  @functools.lru_cache
  def get_w8a8_block_fp8_configs(
      N: int, K: int, block_n: int, block_k: int
@@ -324,7 +450,19 @@ def w8a8_block_fp8_matmul(
              triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
          )

-     _w8a8_block_fp8_matmul[grid](
+     # Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
+     # Empirical testing shows the sweet spot lies when it's less than the # of
+     # compute units available on the device.
+     num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
+         N, config["BLOCK_SIZE_N"]
+     )
+     kernel = (
+         _w8a8_block_fp8_matmul_unrolledx4
+         if (is_hip_ == True and num_workgroups <= get_device_core_count())
+         else _w8a8_block_fp8_matmul
+     )
+
+     kernel[grid](
          A,
          B,
          C,
@@ -6,9 +6,14 @@ from typing import Any, Dict, List, Optional, Tuple, Union

  import torch
  import torch.nn as nn
- from vllm.model_executor.custom_op import CustomOp
+ from vllm import _custom_ops as ops

- from sglang.srt.layers.custom_op_util import register_custom_op
+ from sglang.srt.custom_op import CustomOp
+ from sglang.srt.utils import is_cuda_available
+
+ _is_cuda_available = is_cuda_available()
+ if _is_cuda_available:
+     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace


  def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -53,7 +58,6 @@ def _apply_rotary_emb(
      return torch.stack((o1, o2), dim=-1).flatten(-2)


- @register_custom_op("sglang_rotary_embedding")
  class RotaryEmbedding(CustomOp):
      """Original rotary positional embedding."""

@@ -75,7 +79,9 @@ class RotaryEmbedding(CustomOp):
          self.dtype = dtype

          cache = self._compute_cos_sin_cache()
-         cache = cache.to(dtype)
+         # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
+         if not _is_cuda_available:
+             cache = cache.to(dtype)
          self.cos_sin_cache: torch.Tensor
          self.register_buffer("cos_sin_cache", cache, persistent=False)

@@ -141,17 +147,25 @@
          key: torch.Tensor,
          offsets: Optional[torch.Tensor] = None,
      ) -> Tuple[torch.Tensor, torch.Tensor]:
-         from vllm import _custom_ops as ops
-
-         self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
-         ops.rotary_embedding(
-             positions,
-             query,
-             key,
-             self.head_size,
-             self.cos_sin_cache,
-             self.is_neox_style,
-         )
+         if _is_cuda_available:
+             apply_rope_with_cos_sin_cache_inplace(
+                 positions=positions,
+                 query=query,
+                 key=key,
+                 head_size=self.head_size,
+                 cos_sin_cache=self.cos_sin_cache,
+                 is_neox=self.is_neox_style,
+             )
+         else:
+             self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
+             ops.rotary_embedding(
+                 positions,
+                 query,
+                 key,
+                 self.head_size,
+                 self.cos_sin_cache,
+                 self.is_neox_style,
+             )
          return query, key

      def forward_xpu(
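
The hunk above routes forward_cuda through sgl_kernel's fused apply_rope_with_cos_sin_cache_inplace when CUDA is available and keeps the vLLM ops.rotary_embedding path as a fallback. For reference, a pure-PyTorch sketch of what a neox-style rotary embedding with a precomputed cos/sin cache computes for one head (an illustration only, not the fused kernel; it assumes rotary_dim equals head_size and that the cache stores cos in the first half of the last dimension and sin in the second half):

import torch

def rope_neox_reference(
    x: torch.Tensor,              # (num_tokens, head_size)
    positions: torch.Tensor,      # (num_tokens,), integer positions
    cos_sin_cache: torch.Tensor,  # (max_position, head_size), laid out as [cos | sin]
) -> torch.Tensor:
    half = x.shape[-1] // 2
    cos, sin = cos_sin_cache[positions].split(half, dim=-1)
    x1, x2 = x[..., :half], x[..., half:]
    # Neox-style rotation pairs element i with element i + head_size // 2.
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
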
@@ -72,9 +72,11 @@ class Sampler(nn.Module):
                  # NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
                  # https://github.com/flashinfer-ai/flashinfer/issues/708
                  # so we use the torch implementation.
+
+                 # clamp to avoid -inf
                  logprobs = torch.log(
                      top_p_normalize_probs_torch(probs, sampling_info.top_ps)
-                 )
+                 ).clamp(min=torch.finfo(probs.dtype).min)

              max_top_k_round, batch_size = 32, probs.shape[0]
              uniform_samples = torch.rand(
@@ -83,7 +85,7 @@ class Sampler(nn.Module):
              if sampling_info.need_min_p_sampling:
                  probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                  probs = top_p_renorm_prob(probs, sampling_info.top_ps)
-                 batch_next_token_ids, success = min_p_sampling_from_probs(
+                 batch_next_token_ids = min_p_sampling_from_probs(
                      probs, uniform_samples, sampling_info.min_ps
                  )
              else:
@@ -95,9 +97,9 @@ class Sampler(nn.Module):
                      filter_apply_order="joint",
                  )

-             if self.use_nan_detectioin and not torch.all(success):
-                 logger.warning("Detected errors during sampling!")
-                 batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
+                 if self.use_nan_detectioin and not torch.all(success):
+                     logger.warning("Detected errors during sampling!")
+                     batch_next_token_ids = torch.zeros_like(batch_next_token_ids)

          elif global_server_args_dict["sampling_backend"] == "pytorch":
              # A slower fallback implementation with torch native operations.
@@ -109,9 +111,10 @@ class Sampler(nn.Module):
                  sampling_info.need_min_p_sampling,
              )
              if return_logprob:
+                 # clamp to avoid -inf
                  logprobs = torch.log(
                      top_p_normalize_probs_torch(probs, sampling_info.top_ps)
-                 )
+                 ).clamp(min=torch.finfo(probs.dtype).min)
          else:
              raise ValueError(
                  f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
@@ -0,0 +1,8 @@
+ from .base_backend import BaseLoraBackend
+ from .flashinfer_backend import FlashInferLoraBackend
+ from .triton_backend import TritonLoraBackend
+
+ __all__ = [
+     "FlashInferLoraBackend",
+     "TritonLoraBackend",
+ ]
@@ -0,0 +1,95 @@
+ from typing import Tuple, Union
+
+ import torch
+
+ from sglang.srt.lora.lora import LoraBatchInfo
+
+
+ def get_fuse_output_scaling_add_from_name(name: str) -> bool:
+     mapping = {
+         "triton": True,
+         "flashinfer": False,
+     }
+     return mapping.get(name, False)
+
+
+ def get_fuse_qkv_lora_b_from_name(name: str) -> bool:
+     mapping = {
+         "triton": True,
+         "flashinfer": False,
+     }
+     return mapping.get(name, False)
+
+
+ class BaseLoraBackend:
+     """Base class for different Lora backends.
+     Each backend has its own implementation of Lora kernels.
+
+     Args:
+         name: name of backend
+         batch_info: information of current batch for use
+         fuse_output_scaling_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
+             and the operation of scaling and adding will be fused into kernel
+     """
+
+     def __init__(self, name: str, batch_info: LoraBatchInfo = None):
+         self.name = name
+         self.batch_info = batch_info
+         self.fuse_output_scaling_add = get_fuse_output_scaling_add_from_name(name)
+         self.fuse_qkv_lora_b = get_fuse_qkv_lora_b_from_name(name)
+
+     def run_lora_a_sgemm(
+         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
+     ) -> torch.Tensor:
+         """Run segment Gemm of lora a modules with current backend.
+         The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
+
+         Args:
+             x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
+             weights: a set of lora weights with shape (num_lora, r, input_dim), here r is lora rank
+                 usually input_dim is much larger than r
+         Returns:
+             result with shape (s, r)
+         """
+         pass
+
+     def run_lora_b_sgemm(
+         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
+     ) -> torch.Tensor:
+         """Run segment Gemm of lora b modules with current backend.
+         The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
+
+         Args:
+             x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is lora rank
+             weights: a set of lora weights with shape (num_lora, output_dim, r)
+                 usually output_dim is much larger than r
+         Returns:
+             result with shape (s, output_dim)
+         """
+         pass
+
+     def run_qkv_lora(
+         self,
+         x: torch.Tensor,
+         qkv_lora_a: torch.Tensor,
+         qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
+         *args,
+         **kwargs
+     ) -> torch.Tensor:
+         """Run the lora pass for QKV Layer.
+
+         Args:
+             x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
+             qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim)
+             qkv_lora_b: lora_b module for qkv.
+                 If passed in as a tensor, its shape should be (num_lora,output_dim_q + 2 * output_dim_kv, r)
+                 If passed in as a tuple of two tensors containing:
+                     a lora_b module for q, with shape (1, num_lora, output_dim_q, r)
+                     and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r)
+         Returns:
+             result with shape (s, output_dim_q + 2 * output_dim_kv)
+         """
+         pass
+
+     def set_batch_info(self, batch_info: LoraBatchInfo):
+         self.batch_info = batch_info
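
The run_lora_a_sgemm contract above is a segment GEMM over a ragged batch: each sequence's rows are multiplied by the LoRA-A weight of whichever adapter that sequence is assigned. A naive pure-PyTorch reference of that contract (an illustrative sketch only; the seg_indptr and weight_indices fields mirror what the FlashInfer backend below reads from LoraBatchInfo):

import torch

def segment_lora_a_reference(
    x: torch.Tensor,               # (s, input_dim), s = sum of all sequence lengths
    weights: torch.Tensor,         # (num_lora, r, input_dim)
    seg_indptr: torch.Tensor,      # (bs + 1,), row boundaries of each sequence in x
    weight_indices: torch.Tensor,  # (bs,), adapter index assigned to each sequence
) -> torch.Tensor:
    out = x.new_zeros((x.shape[0], weights.shape[1]))  # (s, r)
    for i in range(weight_indices.numel()):
        start, end = int(seg_indptr[i]), int(seg_indptr[i + 1])
        # (len_i, input_dim) @ (input_dim, r) -> (len_i, r)
        out[start:end] = x[start:end] @ weights[weight_indices[i]].T
    return out
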
@@ -0,0 +1,91 @@
+ from typing import Tuple
+
+ import torch
+
+ from sglang.srt.lora.backend import BaseLoraBackend
+ from sglang.srt.lora.lora import LoraBatchInfo
+ from sglang.srt.utils import is_flashinfer_available
+
+ if is_flashinfer_available():
+     from flashinfer import SegmentGEMMWrapper
+
+
+ class FlashInferLoraBackend(BaseLoraBackend):
+
+     def __init__(self, name: str, batch_info: LoraBatchInfo = None):
+         super().__init__(name, batch_info)
+
+         # Set up SGemm Wrapper from flashinfer
+         # FIXME wait for flashinfer segment gemm update
+         workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
+         self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)
+
+     def run_lora_a_sgemm(
+         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
+     ) -> torch.Tensor:
+
+         return self.segment_gemm.run(
+             x=x,
+             weights=weights,
+             batch_size=self.batch_info.bs,
+             weight_column_major=True,
+             seg_indptr=self.batch_info.seg_indptr,
+             weight_indices=self.batch_info.weight_indices,
+         )
+
+     def run_lora_b_sgemm(
+         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
+     ) -> torch.Tensor:
+
+         return self.segment_gemm.run(
+             x=x,
+             weights=weights,
+             batch_size=self.batch_info.bs,
+             weight_column_major=True,
+             seg_indptr=self.batch_info.seg_indptr,
+             weight_indices=self.batch_info.weight_indices,
+         )
+
+     def run_qkv_lora(
+         self,
+         x: torch.Tensor,
+         qkv_lora_a: torch.Tensor,
+         qkv_lora_b: Tuple[torch.Tensor],
+         *args,
+         **kwargs,
+     ) -> torch.Tensor:
+
+         # Shape of lora_a_output: (s, 3 * r)
+         lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a)
+
+         q_lora_b, kv_lora_b = qkv_lora_b
+         lora_rank = kv_lora_b.shape[-1]
+         output_dim_q = q_lora_b.shape[-2]
+         output_dim_kv = kv_lora_b.shape[-2]
+         lora_output = torch.empty(
+             (x.shape[0], output_dim_q + 2 * output_dim_kv),
+             device=x.device,
+             dtype=x.dtype,
+         )
+
+         # q
+         lora_output[:, :output_dim_q] = self.run_lora_b_sgemm(
+             x=lora_a_output[:, :lora_rank].contiguous(), weights=q_lora_b[0]
+         )
+
+         # kv
+         lora_output[:, output_dim_q : output_dim_q + output_dim_kv] = (
+             self.run_lora_b_sgemm(
+                 x=lora_a_output[:, lora_rank : 2 * lora_rank].contiguous(),
+                 weights=kv_lora_b[0],
+             )
+         )
+
+         lora_output[
+             :, output_dim_q + output_dim_kv : output_dim_q + 2 * output_dim_kv
+         ] = self.run_lora_b_sgemm(
+             x=lora_a_output[:, 2 * lora_rank : 3 * lora_rank].contiguous(),
+             weights=kv_lora_b[1],
+         )
+
+         return lora_output