sglang 0.4.2__py3-none-any.whl → 0.4.2.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. sglang/srt/constrained/outlines_backend.py +9 -1
  2. sglang/srt/custom_op.py +40 -0
  3. sglang/srt/entrypoints/engine.py +2 -2
  4. sglang/srt/layers/activation.py +10 -5
  5. sglang/srt/layers/attention/flashinfer_backend.py +284 -39
  6. sglang/srt/layers/attention/triton_backend.py +71 -7
  7. sglang/srt/layers/attention/triton_ops/decode_attention.py +53 -59
  8. sglang/srt/layers/attention/triton_ops/prefill_attention.py +6 -0
  9. sglang/srt/layers/attention/vision.py +243 -40
  10. sglang/srt/layers/layernorm.py +1 -5
  11. sglang/srt/layers/moe/ep_moe/layer.py +1 -3
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  13. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +200 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +200 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +200 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +178 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +200 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +175 -0
  20. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -11
  21. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -3
  22. sglang/srt/layers/moe/topk.py +4 -0
  23. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  24. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  25. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  26. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  30. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  32. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  33. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  34. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  35. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  36. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  39. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  40. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  42. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  44. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/fp8.py +7 -0
  46. sglang/srt/layers/quantization/fp8_kernel.py +140 -2
  47. sglang/srt/layers/rotary_embedding.py +29 -15
  48. sglang/srt/layers/sampler.py +9 -6
  49. sglang/srt/lora/backend/__init__.py +8 -0
  50. sglang/srt/lora/backend/base_backend.py +95 -0
  51. sglang/srt/lora/backend/flashinfer_backend.py +91 -0
  52. sglang/srt/lora/backend/triton_backend.py +61 -0
  53. sglang/srt/lora/lora.py +127 -112
  54. sglang/srt/lora/lora_manager.py +50 -18
  55. sglang/srt/lora/triton_ops/__init__.py +5 -0
  56. sglang/srt/lora/triton_ops/qkv_lora_b.py +182 -0
  57. sglang/srt/lora/triton_ops/sgemm_lora_a.py +143 -0
  58. sglang/srt/lora/triton_ops/sgemm_lora_b.py +159 -0
  59. sglang/srt/managers/image_processor.py +77 -38
  60. sglang/srt/managers/scheduler.py +17 -3
  61. sglang/srt/mem_cache/base_prefix_cache.py +4 -0
  62. sglang/srt/mem_cache/chunk_cache.py +3 -0
  63. sglang/srt/mem_cache/radix_cache.py +30 -1
  64. sglang/srt/model_executor/cuda_graph_runner.py +77 -80
  65. sglang/srt/model_executor/forward_batch_info.py +58 -59
  66. sglang/srt/model_executor/model_runner.py +2 -2
  67. sglang/srt/models/minicpmv.py +129 -76
  68. sglang/srt/models/mllama.py +16 -56
  69. sglang/srt/models/qwen2.py +4 -1
  70. sglang/srt/models/qwen2_vl.py +19 -9
  71. sglang/srt/server_args.py +19 -2
  72. sglang/srt/speculative/build_eagle_tree.py +4 -2
  73. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +213 -0
  74. sglang/srt/speculative/eagle_utils.py +361 -372
  75. sglang/srt/speculative/eagle_worker.py +177 -45
  76. sglang/srt/utils.py +7 -2
  77. sglang/test/runners.py +2 -0
  78. sglang/utils.py +42 -0
  79. sglang/version.py +1 -1
  80. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/METADATA +16 -7
  81. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/RECORD +84 -45
  82. sglang/srt/layers/custom_op_util.py +0 -25
  83. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/LICENSE +0 -0
  84. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/WHEEL +0 -0
  85. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/top_level.txt +0 -0
sglang/srt/lora/triton_ops/qkv_lora_b.py (new file)
@@ -0,0 +1,182 @@
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.lora.lora import LoraBatchInfo
+
+
+@triton.jit
+def _qkv_lora_b_kernel(
+    # Pointers to matrices
+    x,
+    weights,
+    output,
+    # Parameters of size
+    K,  # K = R
+    max_qkv_out_dim,  # max(output_q_dim, output_kv_dim)
+    # Strides
+    x_stride_0,
+    x_stride_1,
+    w_stride_0,
+    w_stride_1,
+    w_stride_2,
+    output_stride_0,
+    output_stride_1,
+    # Information on sequence lengths and weight id
+    seg_lens,
+    seg_indptr,
+    weight_indices,
+    # Offsets of q/k/v slice on output dimension
+    n_offs,
+    # Meta parameters
+    BLOCK_S: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    # For fused output scaling and adding
+    fuse_scaling_add,
+    scaling,
+):
+    # This kernel packs 3 sgemms (q/k/v) into a single kernel.
+
+    # x: (s, 3 * K), s is the sum of sequence lengths, K equals the LoRA rank
+    # weights: (num_lora, N_Q + 2 * N_KV, K)
+    # output: (s, N_Q + 2 * N_KV)
+    # N_Q >> K, N_KV >> K
+
+    # Current block computes sequence with batch_id,
+    # which starts from row seg_start of x with length seg_len.
+    # qkv_id decides which of q, k, v to compute (0: q, 1: k, 2: v)
+    batch_id = tl.program_id(axis=2)
+    qkv_id = tl.program_id(axis=1)
+    pid = tl.program_id(axis=0)
+    seg_len = tl.load(seg_lens + batch_id)
+    w_index = tl.load(weight_indices + batch_id)
+    seg_start = tl.load(seg_indptr + batch_id)
+    n_start = tl.load(n_offs + qkv_id)
+    n_size = tl.load(n_offs + qkv_id + 1) - n_start
+
+    # The tile in output matrix will have (pid_s, pid_n) as id
+    num_pid_n = tl.cdiv(max_qkv_out_dim, BLOCK_N)
+    pid_s = pid // num_pid_n
+    pid_n = pid % num_pid_n
+
+    # Create pointers for the first block of x and weights[batch_id][n_start: n_end][:]
+    # The pointers will be advanced as we move in the K direction
+    # and accumulate
+    s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
+    n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
+    k_offset = tl.arange(0, BLOCK_K)
+
+    x_ptrs = (x + seg_start * x_stride_0 + (qkv_id * K) * x_stride_1) + (
+        s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
+    )
+    w_ptrs = (weights + w_index * w_stride_0 + n_start * w_stride_1) + (
+        k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
+    )
+
+    # Iterate to compute the block in output matrix
+    partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_K)):
+        x_tile = tl.load(
+            x_ptrs,
+            mask=(s_offset[:, None] < seg_len)
+            and (k_offset[None, :] < K - k * BLOCK_K),
+            other=0.0,
+        )
+        w_tile = tl.load(
+            w_ptrs,
+            mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < n_size),
+            other=0.0,
+        )
+        partial_sum += tl.dot(x_tile, w_tile)
+
+        x_ptrs += BLOCK_K * x_stride_1
+        w_ptrs += BLOCK_K * w_stride_2
+
+    # Store result to output matrix
+    partial_sum *= scaling
+    partial_sum = partial_sum.to(x.dtype.element_ty)
+    output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
+        s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
+    )
+    output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size)
+    if fuse_scaling_add:
+        partial_sum += tl.load(output_ptr, mask=output_mask)
+    tl.store(output_ptr, partial_sum, mask=output_mask)
+
+
+def qkv_lora_b_fwd(
+    x: torch.Tensor,
+    qkv_lora_b: torch.Tensor,
+    batch_info: LoraBatchInfo,
+    output_offset: torch.Tensor,
+    max_qkv_out_dim: int,
+    base_output: torch.Tensor = None,
+    scaling: float = 1.0,
+) -> torch.Tensor:
+
+    # x: (s, 3 * r)
+    # qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
+    # output_offset = [0, output_dim_q, output_dim_q + output_dim_kv,
+    #                  output_dim_q + 2 * output_dim_kv]
+    # max_qkv_out_dim = max(output_dim_q, output_dim_kv)
+    # output: (s, output_dim_q + 2 * output_dim_kv)
+
+    # Compute lora_output with shape (s, output_dim) as follows:
+    # lora_output[:, :output_dim_q] = sgemm(lora_output_a[:, :r], q_lora_b)
+    # lora_output[:, output_dim_q: output_dim_q + output_dim_kv]
+    #     = sgemm(lora_output_a[:, r: 2 * r], kv_lora_b[0])
+    # lora_output[:, output_dim_q + output_dim_kv:]
+    #     = sgemm(lora_output_a[:, 2 * r: 3 * r], kv_lora_b[1])
+
+    # Get dims
+    s = x.shape[0]
+    input_dim = x.shape[1]
+    r = qkv_lora_b.shape[-1]
+    output_dim = qkv_lora_b.shape[-2]
+    assert input_dim == 3 * r
+    assert output_offset.shape[0] == 4
+
+    BLOCK_S = 16
+    BLOCK_R = 16
+    BLOCK_OUT = 64
+
+    grid_b = (
+        triton.cdiv(batch_info.max_len, BLOCK_S)
+        * triton.cdiv(max_qkv_out_dim, BLOCK_OUT),
+        3,  # this dimension decides whether the current block computes q, k, or v
+        batch_info.bs,
+    )
+
+    if base_output is None:
+        output = torch.empty((s, output_dim), device=x.device, dtype=x.dtype)
+        fuse_scaling_add = False
+    else:
+        output = base_output
+        fuse_scaling_add = True
+
+    _qkv_lora_b_kernel[grid_b](
+        x,
+        qkv_lora_b,
+        output,
+        r,
+        max_qkv_out_dim,
+        x.stride(0),
+        x.stride(1),
+        qkv_lora_b.stride(0),
+        qkv_lora_b.stride(1),
+        qkv_lora_b.stride(2),
+        output.stride(0),
+        output.stride(1),
+        batch_info.seg_lens,
+        batch_info.seg_indptr,
+        batch_info.weight_indices,
+        output_offset,
+        BLOCK_S,
+        BLOCK_OUT,
+        BLOCK_R,
+        fuse_scaling_add,
+        scaling,
+    )
+
+    return output
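
The wrapper above is driven by the new LoRA backends (see sglang/srt/lora/backend/ in the file list). For orientation, a minimal smoke-test sketch follows; it assumes a CUDA device, that LoraBatchInfo is a keyword-constructible dataclass exposing exactly the fields the wrapper reads (bs, seg_lens, seg_indptr, max_len, weight_indices), and that the new triton_ops/__init__.py re-exports qkv_lora_b_fwd — none of which is shown in this hunk.

# Hypothetical smoke test, not part of the diff; names per the assumptions above.
import torch
from sglang.srt.lora.lora import LoraBatchInfo
from sglang.srt.lora.triton_ops import qkv_lora_b_fwd

r, n_q, n_kv, num_lora = 16, 128, 64, 2
batch_info = LoraBatchInfo(
    bs=2,
    seg_lens=torch.tensor([3, 5], dtype=torch.int32, device="cuda"),
    seg_indptr=torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda"),
    max_len=5,
    weight_indices=torch.tensor([0, 1], dtype=torch.int64, device="cuda"),
)
x = torch.randn(8, 3 * r, device="cuda", dtype=torch.float16)  # (s, 3 * r), s = 3 + 5
w = torch.randn(num_lora, n_q + 2 * n_kv, r, device="cuda", dtype=torch.float16)
offset = torch.tensor([0, n_q, n_q + n_kv, n_q + 2 * n_kv], dtype=torch.int32, device="cuda")
out = qkv_lora_b_fwd(x, w, batch_info, offset, max(n_q, n_kv))
print(out.shape)  # torch.Size([8, 256])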
sglang/srt/lora/triton_ops/sgemm_lora_a.py (new file)
@@ -0,0 +1,143 @@
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.lora.lora import LoraBatchInfo
+
+
+@triton.jit
+def _sgemm_lora_a_kernel(
+    # Pointers to matrices
+    x,
+    weights,
+    output,
+    # Matrix dimensions
+    N,  # r
+    K,  # input_dim
+    # Strides
+    x_stride_0,
+    x_stride_1,
+    w_stride_0,
+    w_stride_1,
+    w_stride_2,
+    output_stride_0,
+    output_stride_1,
+    # Information on sequence lengths and weight id
+    seg_lens,
+    seg_indptr,
+    weight_indices,
+    # Meta parameters
+    BLOCK_S: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+):
+
+    # x: (s, K), s is the sum of sequence lengths
+    # weights: (num_lora, N, K)
+    # output: (s, N)
+
+    # Current block computes sequence with batch_id,
+    # which starts from row seg_start of x with length seg_len
+    batch_id = tl.program_id(axis=1)
+    pid = tl.program_id(axis=0)
+    seg_len = tl.load(seg_lens + batch_id)
+    w_index = tl.load(weight_indices + batch_id)
+    seg_start = tl.load(seg_indptr + batch_id)
+
+    # The tile in output matrix will have (pid_s, pid_n) as id
+    num_pid_n = tl.cdiv(N, BLOCK_N)
+    pid_s = pid // num_pid_n
+    pid_n = pid % num_pid_n
+
+    # Create pointers for the first block of x and weights[batch_id]
+    # The pointers will be advanced as we move in the K direction
+    # and accumulate
+    s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
+    n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
+    k_offset = tl.arange(0, BLOCK_K)
+    x_ptrs = (x + seg_start * x_stride_0) + (
+        s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
+    )
+    w_ptrs = (weights + w_index * w_stride_0) + (
+        k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
+    )
+
+    # Iterate to compute the block in output matrix
+    partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_K)):
+        x_tile = tl.load(
+            x_ptrs,
+            mask=(s_offset[:, None] < seg_len)
+            and (k_offset[None, :] < K - k * BLOCK_K),
+            other=0.0,
+        )
+        w_tile = tl.load(
+            w_ptrs,
+            mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < N),
+            other=0.0,
+        )
+        partial_sum += tl.dot(x_tile, w_tile)
+
+        x_ptrs += BLOCK_K * x_stride_1
+        w_ptrs += BLOCK_K * w_stride_2
+
+    # Store result to output matrix
+    partial_sum = partial_sum.to(x.dtype.element_ty)
+    output_ptr = (output + seg_start * output_stride_0) + (
+        s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
+    )
+    output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < N)
+    tl.store(output_ptr, partial_sum, mask=output_mask)


+def sgemm_lora_a_fwd(
+    x: torch.Tensor, weights: torch.Tensor, batch_info: LoraBatchInfo
+) -> torch.Tensor:
+    # x: (s, input_dim)
+    # weights: (num_lora, r, input_dim)
+    # output: (s, r)
+    # when called by run_qkv_lora, weights.shape[-2] will be 3 * r
+    # input_dim is much larger than r
+
+    assert x.is_contiguous()
+    assert weights.is_contiguous()
+    assert len(x.shape) == 2
+    assert len(weights.shape) == 3
+
+    S = x.shape[0]
+    R = weights.shape[-2]
+    K = weights.shape[-1]
+    assert x.shape[-1] == K
+
+    # Block shapes
+    BLOCK_S = 16
+    BLOCK_K = 256
+    BLOCK_R = 16
+
+    grid = (
+        triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(R, BLOCK_R),
+        batch_info.bs,
+    )
+
+    output = torch.empty((S, R), device=x.device, dtype=x.dtype)
+    _sgemm_lora_a_kernel[grid](
+        x,
+        weights,
+        output,
+        R,
+        K,
+        x.stride(0),
+        x.stride(1),
+        weights.stride(0),
+        weights.stride(1),
+        weights.stride(2),
+        output.stride(0),
+        output.stride(1),
+        batch_info.seg_lens,
+        batch_info.seg_indptr,
+        batch_info.weight_indices,
+        BLOCK_S,
+        BLOCK_R,
+        BLOCK_K,
+    )
+    return output
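
For readers checking shapes: per segment, the kernel computes an ordinary x @ W.T with a per-sequence choice of adapter. A pure-PyTorch reference sketch (mine, not code from the package), handy for unit-testing the kernel on small inputs:

import torch

def sgemm_lora_a_ref(x, weights, seg_indptr, weight_indices):
    # x: (s, K); weights: (num_lora, N, K); returns (s, N).
    out = torch.empty(x.shape[0], weights.shape[-2], device=x.device, dtype=x.dtype)
    for i in range(len(weight_indices)):
        start, end = seg_indptr[i], seg_indptr[i + 1]
        # Rows of segment i are multiplied by the adapter chosen for sequence i.
        out[start:end] = x[start:end] @ weights[weight_indices[i]].transpose(0, 1)
    return out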
sglang/srt/lora/triton_ops/sgemm_lora_b.py (new file)
@@ -0,0 +1,159 @@
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.lora.lora import LoraBatchInfo
+
+
+@triton.jit
+def _sgemm_lora_b_kernel(
+    # Pointers to matrices
+    x,
+    weights,
+    output,
+    # Matrix dimensions
+    N,  # output_dim
+    K,  # r
+    # Strides
+    x_stride_0,
+    x_stride_1,
+    w_stride_0,
+    w_stride_1,
+    w_stride_2,
+    output_stride_0,
+    output_stride_1,
+    # Information on sequence lengths and weight id
+    seg_lens,
+    seg_indptr,
+    weight_indices,
+    # Meta parameters
+    BLOCK_S: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    # For fused output scaling and adding
+    fuse_scaling_add,
+    scaling,
+):
+    # x: (s, K), s is the sum of sequence lengths
+    # weights: (num_lora, N, K)
+    # output: (s, N)
+
+    # Current block computes sequence with batch_id,
+    # which starts from row seg_start of x with length seg_len
+    batch_id = tl.program_id(axis=1)
+    pid = tl.program_id(axis=0)
+    seg_len = tl.load(seg_lens + batch_id)
+    w_index = tl.load(weight_indices + batch_id)
+    seg_start = tl.load(seg_indptr + batch_id)
+
+    # The tile in output matrix will have (pid_s, pid_n) as id
+    num_pid_n = tl.cdiv(N, BLOCK_N)
+    pid_s = pid // num_pid_n
+    pid_n = pid % num_pid_n
+
+    # Create pointers for the first block of x and weights[batch_id]
+    # The pointers will be advanced as we move in the K direction
+    # and accumulate
+    s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
+    n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
+    k_offset = tl.arange(0, BLOCK_K)
+    x_ptrs = (x + seg_start * x_stride_0) + (
+        s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
+    )
+    w_ptrs = (weights + w_index * w_stride_0) + (
+        k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
+    )
+
+    # Iterate to compute the block in output matrix
+    partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_K)):
+        x_tile = tl.load(
+            x_ptrs,
+            mask=(s_offset[:, None] < seg_len)
+            and (k_offset[None, :] < K - k * BLOCK_K),
+            other=0.0,
+        )
+        w_tile = tl.load(
+            w_ptrs,
+            mask=(k_offset[:, None] < K - k * BLOCK_K),
+            other=0.0,
+        )
+        partial_sum += tl.dot(x_tile, w_tile)
+
+        x_ptrs += BLOCK_K * x_stride_1
+        w_ptrs += BLOCK_K * w_stride_2
+
+    # Store result to output matrix
+    partial_sum *= scaling
+    partial_sum = partial_sum.to(x.dtype.element_ty)
+    output_ptr = (output + seg_start * output_stride_0) + (
+        s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
+    )
+    output_mask = s_offset[:, None] < seg_len
+    if fuse_scaling_add:
+        partial_sum += tl.load(output_ptr, mask=output_mask)
+    tl.store(output_ptr, partial_sum, mask=output_mask)


+def sgemm_lora_b_fwd(
+    x: torch.Tensor,
+    weights: torch.Tensor,
+    batch_info: LoraBatchInfo,
+    base_output: torch.Tensor = None,
+    scaling: float = 1.0,
+) -> torch.Tensor:
+    # x: (s, r)
+    # weights: (num_lora, output_dim, r)
+    # output: (s, output_dim)
+    # output_dim is much larger than r
+
+    assert x.is_contiguous()
+    assert weights.is_contiguous()
+    assert len(x.shape) == 2
+    assert len(weights.shape) == 3
+
+    S = x.shape[0]
+    N = weights.shape[-2]
+    R = weights.shape[-1]
+    assert x.shape[-1] == R
+
+    # Block shapes
+    BLOCK_S = 16
+    BLOCK_R = 16
+    BLOCK_N = 256
+
+    grid = (
+        triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(N, BLOCK_N),
+        batch_info.bs,
+    )
+
+    if base_output is None:
+        output = torch.empty((S, N), device=x.device, dtype=x.dtype)
+        fuse_scaling_add = False
+    else:
+        output = base_output
+        fuse_scaling_add = True
+
+    _sgemm_lora_b_kernel[grid](
+        x,
+        weights,
+        output,
+        N,
+        R,
+        x.stride(0),
+        x.stride(1),
+        weights.stride(0),
+        weights.stride(1),
+        weights.stride(2),
+        output.stride(0),
+        output.stride(1),
+        batch_info.seg_lens,
+        batch_info.seg_indptr,
+        batch_info.weight_indices,
+        BLOCK_S,
+        BLOCK_N,
+        BLOCK_R,
+        fuse_scaling_add,
+        scaling,
+    )
+    return output
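
The only difference from the A-side wrapper is the fused epilogue: when base_output is passed, the kernel scales the gemm result and accumulates it into that tensor in place, saving a separate add pass. Reference semantics as a sketch (my code, same caveats as the reference above):

import torch

def sgemm_lora_b_ref(x, weights, seg_indptr, weight_indices, base_output=None, scaling=1.0):
    # x: (s, R); weights: (num_lora, N, R); returns (s, N).
    out = base_output if base_output is not None else torch.zeros(
        x.shape[0], weights.shape[-2], device=x.device, dtype=x.dtype
    )
    for i in range(len(weight_indices)):
        start, end = seg_indptr[i], seg_indptr[i + 1]
        # out <- out + scaling * x @ W.T; mutates base_output when given,
        # mirroring the kernel's fuse_scaling_add path.
        out[start:end] += scaling * (x[start:end] @ weights[weight_indices[i]].transpose(0, 1))
    return out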
sglang/srt/managers/image_processor.py
@@ -240,6 +240,7 @@ class MllamaImageProcessor(BaseImageProcessor):
 class MiniCPMVImageProcessor(BaseImageProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
+        self.IMAGE_TOKEN = "(<image>./</image>)"
 
     @staticmethod
     def _process_images_task(images, input_text):
@@ -271,7 +272,7 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
     async def process_images_async(
         self,
         image_data: List[Union[str, bytes]],
-        input_text,
+        input_ids,
         request_obj,
         max_req_input_len,
     ):
@@ -282,28 +283,49 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
             image_data = [image_data]
 
         image_hashes, image_sizes = [], []
-        raw_images = []
-        IMAGE_TOKEN = "(<image>./</image>)"
+        all_frames = []
 
-        # roughly calculate the max number of frames
-        # TODO: the process should be applied to all the visual inputs
+        # roughly calculate the max number of frames under the max_req_input_len limit
        def calculate_max_num_frames() -> int:
            # Model-specific
            NUM_TOKEN_PER_FRAME = 330
 
-            ret = (max_req_input_len - len(input_text)) // NUM_TOKEN_PER_FRAME
+            ret = (max_req_input_len - len(input_ids)) // NUM_TOKEN_PER_FRAME
            return min(ret, 100)
 
-        # if cuda OOM set a smaller number
         MAX_NUM_FRAMES = calculate_max_num_frames()
-        print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")
 
-        def encode_video(video_path):
+        # print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")
+
+        def get_estimated_frames_list():
+            """
+            estimate the total frame count from all visual input
+            """
+            # Before processing inputs
+            estimated_frames_list = []
+            for image in image_data:
+                if isinstance(image, str) and image.startswith("video:"):
+                    path = image[len("video:") :]
+                    # Estimate frames for the video
+                    vr = VideoReader(path, ctx=cpu(0))
+                    num_frames = len(vr)
+                else:
+                    # For images, each contributes one frame
+                    num_frames = 1
+                estimated_frames_list.append(num_frames)
+
+            return estimated_frames_list
+
+        estimated_frames_list = get_estimated_frames_list()
+        total_frame_count = sum(estimated_frames_list)
+        scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)
+
+        def encode_video(video_path, frame_count_limit=None):
             if not os.path.exists(video_path):
                 logger.error(f"Video {video_path} does not exist")
                 return []
 
-            if MAX_NUM_FRAMES == 0:
+            if frame_count_limit == 0:
                 return []
 
             def uniform_sample(l, n):
@@ -314,45 +336,63 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
             vr = VideoReader(video_path, ctx=cpu(0))
             sample_fps = round(vr.get_avg_fps() / 1)  # FPS
             frame_idx = [i for i in range(0, len(vr), sample_fps)]
-            if len(frame_idx) > MAX_NUM_FRAMES:
-                frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
+            if frame_count_limit is not None and len(frame_idx) > frame_count_limit:
+                frame_idx = uniform_sample(frame_idx, frame_count_limit)
             frames = vr.get_batch(frame_idx).asnumpy()
             frames = [Image.fromarray(v.astype("uint8")) for v in frames]
             return frames
 
-        if isinstance(input_text, list):
-            assert len(input_text) and isinstance(input_text[0], int)
-            input_text = self._processor.tokenizer.decode(input_text)
-
+        if isinstance(input_ids, list):
+            assert len(input_ids) and isinstance(input_ids[0], int)
+            input_text = self._processor.tokenizer.decode(input_ids)
+        else:
+            input_text = input_ids
         # MiniCPMV requires each frame of video as a single image token
-        text_parts = input_text.split(IMAGE_TOKEN)
+        text_parts = input_text.split(self.IMAGE_TOKEN)
         new_text_parts = []
 
-        for image_index, image in enumerate(image_data):
-            try:
-                if isinstance(image, str) and image.startswith("video:"):
-                    path = image[len("video:") :]
-                    frames = encode_video(path)
-                else:
-                    raw_image, size = load_image(image)
-                    frames = [raw_image]
-                if len(frames) == 0:
-                    continue
-            except FileNotFoundError as e:
-                print(e)
-                return None
-
-            image_sizes += frames[0].size * len(frames)
-            image_hashes += [hash(image)] * len(frames)
-            raw_images += frames
+        # Process each input with allocated frames
+        for image_index, (image, estimated_frames) in enumerate(
+            zip(image_data, estimated_frames_list)
+        ):
+            if len(all_frames) >= MAX_NUM_FRAMES:
+                frames_to_process = 0
+            else:
+                frames_to_process = max(1, int(estimated_frames * scaling_factor))
+
+            if frames_to_process == 0:
+                frames = []
+            else:
+                try:
+                    if isinstance(image, str) and image.startswith("video:"):
+                        path = image[len("video:") :]
+                        frames = encode_video(path, frame_count_limit=frames_to_process)
+                    else:
+                        raw_image, _size = load_image(image)
+                        frames = [raw_image]
+                    if len(frames) == 0:
+                        continue
+                except FileNotFoundError as e:
+                    print(e)
+                    return None
+                image_sizes += frames[0].size * len(frames)
+                image_hashes += [hash(image)] * len(frames)
+                all_frames += frames
+
+            assert frames_to_process == len(frames)
+
             new_text_parts.append(text_parts[image_index])
-            new_text_parts.append(IMAGE_TOKEN * len(frames))
+
+            if frames_to_process != 0:
+                new_text_parts.append(self.IMAGE_TOKEN * len(frames))
 
         new_text_parts.append(text_parts[-1])
+
        input_text = "".join(new_text_parts)
-        if len(raw_images) == 0:
+
+        if len(all_frames) == 0:
             return None
-        res = await self._process_images(images=raw_images, input_text=input_text)
+        res = await self._process_images(images=all_frames, input_text=input_text)
         pixel_values = res["pixel_values"]
         tgt_sizes = res["tgt_sizes"]
         input_ids = res["input_ids"]
@@ -364,7 +404,6 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
         if tokenizer.slice_start_id:
             slice_start_id = [tokenizer.slice_start_id]
             slice_end_id = [tokenizer.slice_end_id]
-
         return {
             "input_ids": input_ids.flatten().tolist(),
             "pixel_values": pixel_values,
sglang/srt/managers/scheduler.py
@@ -149,6 +149,7 @@ class Scheduler:
             if not self.spec_algorithm.is_none()
             else 1
         )
+        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
 
         # Distributed rank info
         self.dp_size = server_args.dp_size
@@ -831,10 +832,16 @@ class Scheduler:
         available_size = (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
-        if available_size != self.max_total_num_tokens:
+        protected_size = self.tree_cache.protected_size()
+        memory_leak = available_size != (
+            self.max_total_num_tokens
+            if not self.enable_hierarchical_cache
+            else self.max_total_num_tokens - protected_size
+        )
+        if memory_leak:
             msg = (
                 "KV cache pool leak detected!"
-                f"{available_size=}, {self.max_total_num_tokens=}\n"
+                f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
             )
             warnings.warn(msg)
             if crash_on_warnings():
@@ -949,7 +956,14 @@ class Scheduler:
             res = adder.add_one_req(req)
             if res != AddReqResult.CONTINUE:
                 if res == AddReqResult.NO_TOKEN:
-                    self.batch_is_full = True
+                    if self.enable_hierarchical_cache:
+                        # Set batch_is_full after making sure there are requests that can be served
+                        self.batch_is_full = len(adder.can_run_list) > 0 or (
+                            self.running_batch is not None
+                            and not self.running_batch.is_empty()
+                        )
+                    else:
+                        self.batch_is_full = True
                 break
             if self.server_args.prefill_only_one_req:
                 break
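
The reworked leak check encodes a single invariant: every KV page must be either free, evictable from the radix tree, or (when the hierarchical cache is enabled) protected by it. As a standalone sketch of that invariant (a hypothetical helper, not actual scheduler code):

def kv_pool_leaked(available_size: int, protected_size: int,
                   max_total_num_tokens: int, hierarchical_cache: bool) -> bool:
    # Protected pages are pinned by the hierarchical cache, so they are
    # legitimately neither free nor evictable; subtract them before comparing.
    expected = max_total_num_tokens - (protected_size if hierarchical_cache else 0)
    return available_size != expected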