sglang 0.4.4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,342 @@
+ from typing import Tuple
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from sglang.srt.layers.moe.topk import fused_topk
+
+
+ @triton.jit
+ def fused_moe_router_kernel(
+     input_ptr,  # input (bs, hidden_dim)
+     moe_router_weight_ptr,  # input (num_experts, hidden_dim)
+     topk_weights_ptr,  # output (bs, topk)
+     topk_ids_ptr,  # output (bs, topk)
+     num_experts: tl.constexpr,
+     topk: tl.constexpr,
+     moe_softcapping: tl.constexpr,
+     moe_renormalize: tl.constexpr,  # not supported
+     hidden_dim: tl.constexpr,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     pid = tl.program_id(axis=0)
+
+     offsets = tl.arange(0, BLOCK_SIZE)
+     mask = offsets < hidden_dim
+
+     # moe_router_weight is k major
+     expert_offsets = tl.arange(0, num_experts)[:, None]
+     router_mask = mask[None, :]
+     w_router = tl.load(
+         moe_router_weight_ptr + expert_offsets * hidden_dim + offsets[None, :],
+         mask=router_mask,
+         other=0.0,
+     )
+
+     x = tl.load(input_ptr + pid * hidden_dim + offsets, mask=mask, other=0.0)
+
+     # todo: tl.dot?
+     logits = tl.sum((w_router.to(tl.float32) * x[None, :].to(tl.float32)), axis=-1)
+
+     # logit softcap
+     logits_scaled = logits / moe_softcapping
+     exped = tl.exp(2 * logits_scaled)
+     top = exped - 1
+     bottom = exped + 1
+     logits_softcapped = top / bottom * moe_softcapping
+
+     # topk
+     # assert 1 <= topk <= num_experts
+
+     # 5.38 us
+
+     top1 = tl.argmax(logits_softcapped, axis=0)
+     tl.store(topk_ids_ptr + pid * topk + 0, top1)  # 5.63 us
+
+     top1_v = tl.max(logits_softcapped, axis=0)
+     invsumexp = 1.0 / tl.sum(tl.exp(logits_softcapped - top1_v), axis=0)
+
+     tl.store(
+         topk_weights_ptr + pid * topk + 0,
+         invsumexp,
+     )  # 5.73 us
+
+     if topk >= 2:
+         top2 = tl.argmax(
+             tl.where(
+                 tl.arange(0, num_experts) != top1, logits_softcapped, float("-inf")
+             ),
+             axis=0,
+         )
+         tl.store(topk_ids_ptr + pid * topk + 1, top2)
+         top2_v = tl.sum(logits_softcapped * (tl.arange(0, num_experts) == top2), axis=0)
+         tl.store(
+             topk_weights_ptr + pid * topk + 1,
+             tl.exp(top2_v - top1_v) * invsumexp,
+         )  # 5.95us
+
+     # probably slow
+     if topk > 2:
+         topk_mask = tl.full(logits_softcapped.shape, 1.0, dtype=logits_softcapped.dtype)
+         topk_mask = tl.where(
+             tl.arange(0, num_experts) != top1, topk_mask, float("-inf")
+         )
+         topk_mask = tl.where(
+             tl.arange(0, num_experts) != top2, topk_mask, float("-inf")
+         )
+         for i in range(2, topk):
+             topi = tl.argmax(logits_softcapped + topk_mask, axis=0)
+             topk_mask = tl.where(
+                 tl.arange(0, num_experts) != topi, topk_mask, float("-inf")
+             )
+             tl.store(topk_ids_ptr + pid * topk + i, topi)
+             topi_v = tl.sum(
+                 logits_softcapped * (tl.arange(0, num_experts) == topi), axis=0
+             )
+             tl.store(
+                 topk_weights_ptr + pid * topk + i,
+                 tl.exp(topi_v - top1_v) * invsumexp,
+             )
+     # assert not moe_renormalize, "moe weight renormalization not implemented"
+
+
+ def fused_moe_router_impl(
+     x: torch.Tensor,
+     router_weight: torch.Tensor,
+     topk: int,
+     moe_softcapping: float,
+ ):
+     assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
+     bs, hidden_dim = x.shape
+     num_experts = router_weight.shape[0]
+
+     # router_logits = torch.empty((bs, num_experts), dtype=torch.float32, device=x.device)
+     topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
+     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
+
+     grid = lambda meta: (bs,)
+     config = {
+         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
+         "num_warps": max(
+             min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+         ),
+     }
+
+     fused_moe_router_kernel[grid](
+         x,
+         router_weight,
+         topk_weights,
+         topk_ids,
+         num_experts=num_experts,
+         topk=topk,
+         moe_softcapping=moe_softcapping,
+         moe_renormalize=False,
+         hidden_dim=hidden_dim,
+         **config,
+     )
+
+     return topk_weights, topk_ids
+
+
+ @triton.jit
+ def fused_moe_router_large_bs_kernel(
+     a_ptr,  # input (bs, hidden_dim)
+     b_ptr,  # input (num_experts, hidden_dim)
+     topk_weights_ptr,  # output (bs, topk)
+     topk_ids_ptr,  # output (bs, topk)
+     bs,
+     num_experts: tl.constexpr,
+     topk: tl.constexpr,  # only support topk == 1
+     moe_softcapping: tl.constexpr,
+     moe_renormalize: tl.constexpr,  # not supported
+     K: tl.constexpr,
+     BLOCK_SIZE_M: tl.constexpr,
+     BLOCK_SIZE_N: tl.constexpr,
+     BLOCK_SIZE_K: tl.constexpr,
+     stride_am: tl.constexpr,
+     stride_bn: tl.constexpr,
+ ):
+
+     # 1. get block id
+     pid = tl.program_id(axis=0)
+
+     # 2. create pointers for the first block of A and B
+     # 2.1. setup a_ptrs with offsets in m and k
+     offs_m = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)[:, None]
+     bs_mask = offs_m < bs
+     offs_k = tl.arange(0, BLOCK_SIZE_K)[None, :]
+     a_ptrs = a_ptr + (offs_m * stride_am + offs_k)
+
+     # 2.2. setup b_ptrs with offsets in k and n.
+     # Note: b matrix is k-major.
+     offs_k = tl.arange(0, BLOCK_SIZE_K)[None, :]
+     offs_n = tl.arange(0, BLOCK_SIZE_N)[:, None]
+     expert_mask = offs_n < num_experts
+     b_ptrs = b_ptr + (offs_n * stride_bn + offs_k)
+
+     # 3. Create an accumulator of float32 of size [BLOCK_SIZE_M, BLOCK_SIZE_N]
+     # 3.1. iterate in K dimension
+     # 3.2. transpose tile B
+     acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+     for k in range(0, K // BLOCK_SIZE_K):  # hidden_dim % BLOCK_SIZE_K == 0
+         a = tl.load(
+             a_ptrs,
+             mask=bs_mask,
+             other=0.0,
+         ).to(tl.float32)
+         b = tl.load(b_ptrs, mask=expert_mask, other=0.0).to(tl.float32).T
+         acc += tl.dot(a, b)
+
+         # Advance the ptrs to the next K block.
+         a_ptrs += BLOCK_SIZE_K
+         b_ptrs += BLOCK_SIZE_K
+
+     # 4. logit softcap
+     logits_scaled = acc / moe_softcapping
+     exped = tl.exp(2 * logits_scaled)
+     logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
+
+     # 5. top1
+     cond = tl.arange(0, BLOCK_SIZE_N)[None, :] < num_experts
+     top1 = tl.argmax(tl.where(cond, logits_softcapped, float("-inf")), axis=1)
+     top1_v = tl.max(
+         tl.where(cond, logits_softcapped, float("-inf")), axis=1, keep_dims=True
+     )
+     invsumexp = 1.0 / tl.sum(
+         tl.where(cond, tl.exp(logits_softcapped - top1_v), 0.0), axis=1
+     )
+
+     # 6. store to output
+     offs_topk = pid * topk * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+     topk_mask = offs_topk < bs
+     tl.store(topk_ids_ptr + offs_topk, top1, mask=topk_mask)
+     tl.store(
+         topk_weights_ptr + offs_topk,
+         invsumexp,
+         mask=topk_mask,
+     )
+
+
+ def fused_moe_router_large_bs_impl(
+     x: torch.Tensor,
+     router_weight: torch.Tensor,
+     topk: int,
+     moe_softcapping: float,
+     BLOCK_SIZE_M: int,
+     BLOCK_SIZE_N: int,
+     BLOCK_SIZE_K: int,
+ ):
+     assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
+     bs, hidden_dim = x.shape
+     num_experts = router_weight.shape[0]
+
+     assert num_experts <= BLOCK_SIZE_N
+     assert hidden_dim % BLOCK_SIZE_K == 0
+     assert topk == 1
+
+     topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
+     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
+
+     grid = (triton.cdiv(bs, BLOCK_SIZE_M) * triton.cdiv(num_experts, BLOCK_SIZE_N),)
+
+     fused_moe_router_large_bs_kernel[grid](
+         a_ptr=x,
+         b_ptr=router_weight,
+         topk_weights_ptr=topk_weights,
+         topk_ids_ptr=topk_ids,
+         bs=bs,
+         num_experts=num_experts,
+         topk=topk,
+         moe_softcapping=moe_softcapping,
+         moe_renormalize=False,
+         K=hidden_dim,
+         BLOCK_SIZE_M=BLOCK_SIZE_M,
+         BLOCK_SIZE_N=BLOCK_SIZE_N,
+         BLOCK_SIZE_K=BLOCK_SIZE_K,
+         stride_am=hidden_dim,
+         stride_bn=hidden_dim,
+     )
+
+     return topk_weights, topk_ids
+
+
+ def fused_moe_router_shim(
+     moe_softcapping,
+     hidden_states,
+     gating_output,
+     topk,
+     renormalize,
+ ):
+     assert not renormalize
+     assert (
+         len(hidden_states.shape) == 2
+         and hidden_states.shape[1] == gating_output.shape[1]
+     )
+     bs, hidden_dim = hidden_states.shape
+     num_experts = gating_output.shape[0]
+     BLOCK_SIZE_M = 32
+     BLOCK_SIZE_N = 16
+     BLOCK_SIZE_K = 256
+     if (
+         bs >= 512
+         and topk == 1
+         and num_experts <= BLOCK_SIZE_N
+         and hidden_dim % BLOCK_SIZE_K == 0
+     ):
+         return fused_moe_router_large_bs_impl(
+             x=hidden_states,
+             router_weight=gating_output,
+             topk=topk,
+             moe_softcapping=moe_softcapping,
+             BLOCK_SIZE_M=BLOCK_SIZE_M,
+             BLOCK_SIZE_N=BLOCK_SIZE_N,
+             BLOCK_SIZE_K=BLOCK_SIZE_K,
+         )
+     else:
+         return fused_moe_router_impl(
+             x=hidden_states,
+             router_weight=gating_output,
+             topk=topk,
+             moe_softcapping=moe_softcapping,
+         )
+
+
+ class FusedMoeRouter:
+     def __init__(self, router_linear, topk, moe_softcapping) -> None:
+         self.router_linear = router_linear
+         self.topk = topk
+         self.moe_softcapping = moe_softcapping
+
+     def __call__(self, *args, **kwargs):
+         return self.forward(*args, **kwargs)
+
+     def forward(
+         self, x: torch.Tensor, residual: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         if x.is_cuda:
+             return self.forward_cuda(x, residual)
+         else:
+             return self.forward_vllm(x, residual)
+
+     def forward_cuda(
+         self, x: torch.Tensor, autotune=False
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         return fused_moe_router_shim(
+             moe_softcapping=self.moe_softcapping,
+             hidden_states=x,
+             gating_output=self.router_linear.weight,
+             topk=self.topk,
+             renormalize=False,
+         )
+
+     def forward_vllm(
+         self,
+         x: torch.Tensor,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         # g, _ = self.router_linear.forward(x)
+         g = x.float() @ self.router_linear.weight.T.float()
+
+         g = torch.tanh(g.float() / self.moe_softcapping) * self.moe_softcapping
+
+         return fused_topk(x, g, self.topk, False)
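For reference, the routing math the kernels above implement can be written with plain PyTorch ops. This is an illustrative sketch only, not part of the package; the helper name reference_router is hypothetical, and ties in the top-k selection may be broken differently than in the Triton kernels:

import torch

def reference_router(x, router_weight, topk, moe_softcapping):
    # Router logits, then logit softcapping; (e^(2z) - 1) / (e^(2z) + 1) == tanh(z).
    logits = x.float() @ router_weight.float().T                   # (bs, num_experts)
    capped = torch.tanh(logits / moe_softcapping) * moe_softcapping
    # Softmax over all experts; the selected top-k weights are NOT renormalized.
    probs = torch.softmax(capped, dim=-1)
    topk_weights, topk_ids = probs.topk(topk, dim=-1)
    return topk_weights, topk_ids.to(torch.int32)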
@@ -248,6 +248,8 @@ class HiCacheController:
          if device_indices is None:
              return None
          self.mem_pool_host.protect_load(host_indices)
+         # to ensure the device indices are ready before accessed by another CUDA stream
+         torch.cuda.current_stream().synchronize()
          self.load_queue.put(
              CacheOperation(host_indices, device_indices, node_id, priority)
          )
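The synchronize() added above orders the host-side enqueue after the work that produced device_indices on the current stream, so the loading thread's stream sees valid indices. A minimal illustration of the pattern (hypothetical buffer and index names, requires a CUDA device; not taken from the package):

import torch

buffer = torch.zeros(1024, 64, device="cuda")
side_stream = torch.cuda.Stream()

# Indices produced by work queued on the default stream.
idx = torch.randint(0, 1024, (256,), device="cuda")

# Without this, work launched on the side stream could consume `idx`
# before the default stream has finished materializing it.
torch.cuda.current_stream().synchronize()

with torch.cuda.stream(side_stream):
    chunk = buffer.index_select(0, idx)  # consumer running on another stream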
@@ -54,7 +54,7 @@ class LoadBalanceMethod(Enum):
  class DataParallelController:
      """A controller that dispatches requests to multiple data parallel workers."""

-     def __init__(self, server_args, port_args) -> None:
+     def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None:
          # Parse args
          self.max_total_num_tokens = None
          self.server_args = server_args
@@ -361,7 +361,7 @@ class Req:
          ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
              self.output_token_ids_logprobs_idx
          ) = None
-         self.hidden_states = []
+         self.hidden_states: List[List[float]] = []

          # Embedding (return values)
          self.embedding = None
@@ -434,6 +434,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
                      req_to_token_pool=self.req_to_token_pool,
                      token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                      tp_cache_group=self.tp_worker.get_tp_cpu_group(),
+                     page_size=self.page_size,
                  )
              else:
                  self.tree_cache = RadixCache(
@@ -997,7 +998,7 @@ class Scheduler(SchedulerOutputProcessorMixin):

          # Handle DP attention
          if self.server_args.enable_dp_attention:
-             ret = self.prepare_dp_attn_batch(ret)
+             ret, _ = self.prepare_dp_attn_batch(ret)

          return ret

@@ -1269,39 +1270,72 @@ class Scheduler(SchedulerOutputProcessorMixin):
          # Check if other DP workers have running batches
          if local_batch is None:
              num_tokens = 0
+             global_num_tokens_for_logprob = 0
          elif local_batch.forward_mode.is_decode():
              num_tokens = local_batch.batch_size()
+             if not self.spec_algorithm.is_none() and self.spec_algorithm.is_eagle():
+                 num_tokens = num_tokens * self.server_args.speculative_num_draft_tokens
+             global_num_tokens_for_logprob = num_tokens
          else:
              num_tokens = local_batch.extend_num_tokens
+             global_num_tokens_for_logprob = sum(
+                 [
+                     # We should have at least 1 token for sample in every case.
+                     max(extend_len - logprob_start_len, 1)
+                     for logprob_start_len, extend_len in zip(
+                         local_batch.extend_logprob_start_lens, local_batch.extend_lens
+                     )
+                 ]
+             )
+
+         if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
+             can_cuda_graph = 1
+         else:
+             can_cuda_graph = 0
+
+         if not self.spec_algorithm.is_none():
+             # TODO(sang): Support cuda graph when idle batch is there.
+             if local_batch is None or local_batch.forward_mode.is_idle():
+                 can_cuda_graph = 0

-         local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
-         global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
+         is_extend_in_batch = (
+             local_batch.forward_mode.is_extend() if local_batch else False
+         )
+         local_info = torch.tensor(
+             [
+                 num_tokens,
+                 can_cuda_graph,
+                 global_num_tokens_for_logprob,
+                 is_extend_in_batch,
+             ],
+             dtype=torch.int64,
+         )
+         global_info = torch.empty(
+             (self.server_args.dp_size, self.attn_tp_size, 4),
+             dtype=torch.int64,
+         )
          torch.distributed.all_gather_into_tensor(
-             global_num_tokens,
-             local_num_tokens,
+             global_info.flatten(),
+             local_info,
              group=self.tp_cpu_group,
          )
+         global_num_tokens = global_info[:, 0, 0].tolist()
+         can_cuda_graph = min(global_info[:, 0, 1].tolist())
+         global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
+         is_extend_in_batch = global_info[:, 0, 3].tolist()

-         if local_batch is None and global_num_tokens.max().item() > 0:
+         if local_batch is None and max(global_num_tokens) > 0:
              local_batch = self.get_idle_batch()

          if local_batch is not None:
-             local_batch.global_num_tokens = global_num_tokens.tolist()
+             local_batch.global_num_tokens = global_num_tokens
+             local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob

              # Check forward mode for cuda graph
              if not self.server_args.disable_cuda_graph:
-                 forward_mode_state = torch.tensor(
-                     (1 if local_batch.forward_mode.is_decode_or_idle() else 0),
-                     dtype=torch.int32,
-                 )
-                 torch.distributed.all_reduce(
-                     forward_mode_state,
-                     op=torch.distributed.ReduceOp.MIN,
-                     group=self.tp_cpu_group,
-                 )
-                 local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
+                 local_batch.can_run_dp_cuda_graph = can_cuda_graph

-         return local_batch
+         return local_batch, any(is_extend_in_batch)

      def get_idle_batch(self):
          idle_batch = ScheduleBatch.init_new(
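The rewrite above packs num_tokens, the cuda-graph flag, the logprob token count, and the extend flag into one int64 tensor so a single all_gather_into_tensor replaces the previous separate gather and all_reduce. A minimal sketch of that pattern (hypothetical helper, assumes torch.distributed is already initialized; not the scheduler's actual code):

import torch
import torch.distributed as dist

def gather_rank_flags(num_tokens: int, can_cuda_graph: bool, group=None):
    world_size = dist.get_world_size(group)
    # One packed tensor -> one collective instead of several.
    local = torch.tensor([num_tokens, int(can_cuda_graph)], dtype=torch.int64)
    gathered = torch.empty((world_size, 2), dtype=torch.int64)
    dist.all_gather_into_tensor(gathered.flatten(), local, group=group)
    # Unpack per-rank values; min() over the flag plays the role of an all_reduce(MIN).
    return gathered[:, 0].tolist(), int(gathered[:, 1].min())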
@@ -111,6 +111,7 @@ class SchedulerOutputProcessorMixin:
                          ]
                          .cpu()
                          .clone()
+                         .tolist()
                      )

                  if req.grammar is not None:
@@ -245,7 +246,9 @@ class SchedulerOutputProcessorMixin:
                      )

              if req.return_hidden_states and logits_output.hidden_states is not None:
-                 req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
+                 req.hidden_states.append(
+                     logits_output.hidden_states[i].cpu().clone().tolist()
+                 )

              if req.grammar is not None and batch.spec_algorithm.is_none():
                  req.grammar.accept_token(next_token_id)
@@ -25,11 +25,17 @@ class HiRadixCache(RadixCache):
          req_to_token_pool: ReqToTokenPool,
          token_to_kv_pool_allocator: TokenToKVPoolAllocator,
          tp_cache_group: torch.distributed.ProcessGroup,
+         page_size: int,
      ):
+         if page_size != 1:
+             raise ValueError(
+                 "Page size larger than 1 is not yet supported in HiRadixCache."
+             )
          self.token_to_kv_pool_host = MHATokenToKVPoolHost(
              token_to_kv_pool_allocator.get_kvcache()
          )
          self.tp_group = tp_cache_group
+         self.page_size = page_size

          self.load_cache_event = threading.Event()
          self.cache_controller = HiCacheController(
@@ -45,7 +51,9 @@ class HiRadixCache(RadixCache):
          # todo: dynamically adjust the threshold
          self.write_through_threshold = 1
          self.load_back_threshold = 10
-         super().__init__(req_to_token_pool, token_to_kv_pool_allocator, disable=False)
+         super().__init__(
+             req_to_token_pool, token_to_kv_pool_allocator, self.page_size, disable=False
+         )

      def reset(self):
          TreeNode.counter = 0
@@ -326,7 +326,7 @@ class MHATokenToKVPool(KVCache):
              cache_k = cache_k.view(self.store_dtype)
              cache_v = cache_v.view(self.store_dtype)

-         if self.capture_mode:
+         if self.capture_mode and cache_k.shape[0] < 4:
              self.alt_stream.wait_stream(torch.cuda.current_stream())
              with torch.cuda.stream(self.alt_stream):
                  self.k_buffer[layer_id][loc] = cache_k
@@ -591,6 +591,9 @@ class MHATokenToKVPoolHost:
      def get_flat_data(self, indices):
          return self.kv_buffer[:, :, indices]

+     def get_flat_data_by_layer(self, indices, layer_id):
+         return self.kv_buffer[:, layer_id, indices]
+
      def assign_flat_data(self, indices, flat_data):
          self.kv_buffer[:, :, indices] = flat_data