sglang 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. sglang/__init__.py +5 -1
  2. sglang/api.py +8 -3
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +148 -12
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/global_config.py +11 -1
  8. sglang/lang/chat_template.py +9 -2
  9. sglang/lang/interpreter.py +161 -81
  10. sglang/lang/ir.py +29 -11
  11. sglang/lang/tracer.py +1 -1
  12. sglang/launch_server.py +1 -2
  13. sglang/launch_server_llavavid.py +31 -0
  14. sglang/srt/constrained/fsm_cache.py +3 -0
  15. sglang/srt/flush_cache.py +16 -0
  16. sglang/srt/hf_transformers_utils.py +83 -2
  17. sglang/srt/layers/extend_attention.py +17 -0
  18. sglang/srt/layers/fused_moe.py +485 -0
  19. sglang/srt/layers/logits_processor.py +12 -7
  20. sglang/srt/layers/radix_attention.py +10 -3
  21. sglang/srt/layers/token_attention.py +16 -1
  22. sglang/srt/managers/controller/dp_worker.py +110 -0
  23. sglang/srt/managers/controller/infer_batch.py +619 -0
  24. sglang/srt/managers/controller/manager_multi.py +191 -0
  25. sglang/srt/managers/controller/manager_single.py +97 -0
  26. sglang/srt/managers/controller/model_runner.py +462 -0
  27. sglang/srt/managers/controller/radix_cache.py +267 -0
  28. sglang/srt/managers/controller/schedule_heuristic.py +59 -0
  29. sglang/srt/managers/controller/tp_worker.py +791 -0
  30. sglang/srt/managers/detokenizer_manager.py +45 -45
  31. sglang/srt/managers/io_struct.py +26 -10
  32. sglang/srt/managers/router/infer_batch.py +130 -74
  33. sglang/srt/managers/router/manager.py +7 -9
  34. sglang/srt/managers/router/model_rpc.py +224 -135
  35. sglang/srt/managers/router/model_runner.py +94 -107
  36. sglang/srt/managers/router/radix_cache.py +54 -18
  37. sglang/srt/managers/router/scheduler.py +23 -34
  38. sglang/srt/managers/tokenizer_manager.py +183 -88
  39. sglang/srt/model_config.py +5 -2
  40. sglang/srt/models/commandr.py +15 -22
  41. sglang/srt/models/dbrx.py +22 -29
  42. sglang/srt/models/gemma.py +14 -24
  43. sglang/srt/models/grok.py +671 -0
  44. sglang/srt/models/llama2.py +24 -23
  45. sglang/srt/models/llava.py +85 -25
  46. sglang/srt/models/llavavid.py +298 -0
  47. sglang/srt/models/mixtral.py +254 -130
  48. sglang/srt/models/mixtral_quant.py +373 -0
  49. sglang/srt/models/qwen.py +28 -25
  50. sglang/srt/models/qwen2.py +17 -22
  51. sglang/srt/models/stablelm.py +21 -26
  52. sglang/srt/models/yivl.py +17 -25
  53. sglang/srt/openai_api_adapter.py +140 -95
  54. sglang/srt/openai_protocol.py +10 -1
  55. sglang/srt/server.py +101 -52
  56. sglang/srt/server_args.py +59 -11
  57. sglang/srt/utils.py +242 -75
  58. sglang/test/test_programs.py +44 -0
  59. sglang/test/test_utils.py +32 -1
  60. sglang/utils.py +95 -26
  61. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/METADATA +23 -13
  62. sglang-0.1.17.dist-info/RECORD +81 -0
  63. sglang/srt/backend_config.py +0 -13
  64. sglang/srt/models/dbrx_config.py +0 -281
  65. sglang/srt/weight_utils.py +0 -402
  66. sglang-0.1.15.dist-info/RECORD +0 -69
  67. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
  68. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
  69. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/srt/layers/extend_attention.py
@@ -8,6 +8,12 @@ from sglang.srt.utils import wrap_kernel_launcher
 CUDA_CAPABILITY = torch.cuda.get_device_capability()


+@triton.jit
+def tanh(x):
+    # Tanh is just a scaled sigmoid
+    return 2 * tl.sigmoid(2 * x) - 1
+
+
 @triton.jit
 def _fwd_kernel(
     Q_Extend,
@@ -39,6 +45,7 @@ def _fwd_kernel(
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    logit_cap: tl.constexpr,
 ):
     cur_seq = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -90,6 +97,10 @@ def _fwd_kernel(
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
         qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
         qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))

         n_e_max = tl.maximum(tl.max(qk, 1), e_max)
@@ -126,6 +137,10 @@ def _fwd_kernel(
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
         qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
         mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
             start_n + offs_n[None, :]
         )
@@ -176,6 +191,7 @@ def extend_attention_fwd(
     b_seq_len_extend,
     max_len_in_batch,
     max_len_extend,
+    logit_cap=-1,
 ):
     """
     q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -271,6 +287,7 @@ def extend_attention_fwd(
         BLOCK_N=BLOCK_N,
         num_warps=num_warps,
         num_stages=num_stages,
+        logit_cap=logit_cap,
     )
     cached_kernel = wrap_kernel_launcher(_fwd_kernel)

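The extend_attention.py changes above thread a new logit_cap argument through the prefill kernel and soft-cap the attention scores with a tanh before softmax (the Triton tanh helper uses the identity tanh(x) = 2*sigmoid(2x) - 1). As a quick reference, here is a minimal PyTorch sketch of the same transform; the function name and the example cap value are illustrative, not part of the package.

```python
import torch

def soft_cap_logits(qk: torch.Tensor, logit_cap: float) -> torch.Tensor:
    # Mirrors the kernel's `if logit_cap > 0` branch: a non-positive cap disables
    # capping, otherwise scores are squashed into (-logit_cap, logit_cap).
    if logit_cap <= 0:
        return qk
    return logit_cap * torch.tanh(qk / logit_cap)

x = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
print(soft_cap_logits(x, logit_cap=30.0))  # large scores are bounded near +/-30
print(torch.allclose(torch.tanh(x), 2 * torch.sigmoid(2 * x) - 1))  # the scaled-sigmoid identity
```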
sglang/srt/layers/fused_moe.py (new file)
@@ -0,0 +1,485 @@
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/layers/fused_moe/fused_moe.py#L1
+"""Fused MoE kernel."""
+import functools
+import json
+import os
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from vllm import _custom_ops as ops
+from vllm.logger import init_logger
+from vllm.utils import is_hip
+
+logger = init_logger(__name__)
+
+
+@triton.jit
+def fused_moe_kernel(
+    # Pointers to matrices
+    a_ptr,
+    b_ptr,
+    c_ptr,
+    a_scale_ptr,
+    b_scale_ptr,
+    topk_weights_ptr,
+    sorted_token_ids_ptr,
+    expert_ids_ptr,
+    num_tokens_post_padded_ptr,
+    # Matrix dimensions
+    N,
+    K,
+    EM,
+    num_valid_tokens,
+    # The stride variables represent how much to increase the ptr by when
+    # moving by 1 element in a particular dimension. E.g. `stride_am` is
+    # how much to increase `a_ptr` by to get the element one row down
+    # (A has M rows).
+    stride_am,
+    stride_ak,
+    stride_be,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    # Meta-parameters
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    MUL_ROUTED_WEIGHT: tl.constexpr,
+    top_k: tl.constexpr,
+    compute_type: tl.constexpr,
+    use_fp8: tl.constexpr,
+):
+    """
+    Implements the fused computation for a Mixture of Experts (MOE) using
+    token and expert matrices.
+
+    Key Parameters:
+    - A: The input tensor representing tokens with shape (*, K), where '*' can
+        be any shape representing batches and K is the feature dimension of
+        each token.
+    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
+        the number of experts, K is the input feature dimension, and N is
+        the output feature dimension.
+    - C: The output cache tensor with shape (M, topk, N), where M is the
+        total number of tokens post padding, topk is the number of times
+        each token is repeated, and N is the output feature dimension.
+    - sorted_token_ids: A tensor containing the sorted indices of tokens,
+        repeated topk times and arranged by the expert index they are
+        assigned to.
+    - expert_ids: A tensor containing the indices of the expert for each
+        block. It determines which expert matrix from B should be used for
+        each block in A.
+    This kernel performs the multiplication of a token by its corresponding
+    expert matrix as determined by `expert_ids`. The sorting of
+    `sorted_token_ids` by expert index and padding ensures divisibility by
+    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
+    multiplication across different blocks processed by the same expert.
+    """
+    # -----------------------------------------------------------
+    # Map program ids `pid` to the block of C it should compute.
+    # This is done in a grouped ordering to promote L2 data reuse.
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    # ----------------------------------------------------------
+    # Create pointers for the first blocks of A and B.
+    # We will advance this pointer as we move in the K direction
+    # and accumulate
+    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
+    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
+    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
+    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
+        return
+    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
+    token_mask = offs_token < num_valid_tokens
+
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
+                      offs_k[None, :] * stride_ak)
+
+    off_experts = tl.load(expert_ids_ptr + pid_m)
+    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
+                                                offs_bn[None, :] * stride_bn)
+
+    if use_fp8:
+        a_scale = tl.load(a_scale_ptr)
+        b_scale = tl.load(b_scale_ptr + off_experts)
+
+    # -----------------------------------------------------------
+    # Iterate to compute a block of the C matrix.
+    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+    # of fp32 values for higher accuracy.
+    # `accumulator` will be converted back to fp16 after the loop.
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        # Load the next block of A and B, generate a mask by checking the
+        # K dimension.
+        a = tl.load(a_ptrs,
+                    mask=token_mask[:, None] &
+                    (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+                    other=0.0)
+        b = tl.load(b_ptrs,
+                    mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
+                    other=0.0)
+        # We accumulate along the K dimension.
+        if use_fp8:
+            accumulator = tl.dot(a, b, acc=accumulator)
+        else:
+            accumulator += tl.dot(a, b)
+        # Advance the ptrs to the next K block.
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    if MUL_ROUTED_WEIGHT:
+        moe_weight = tl.load(topk_weights_ptr + offs_token,
+                             mask=token_mask,
+                             other=0)
+        accumulator = accumulator * moe_weight[:, None]
+
+    if use_fp8:
+        accumulator = (accumulator * a_scale * b_scale).to(compute_type)
+    else:
+        accumulator = accumulator.to(compute_type)
+    # -----------------------------------------------------------
+    # Write back the block of the output
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
+        None, :]
+    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
+def moe_align_block_size(
+        topk_ids: torch.Tensor, block_size: int,
+        num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Aligns the token distribution across experts to be compatible with block
+    size for matrix multiplication.
+
+    Parameters:
+    - topk_ids: A tensor of shape [total_tokens, top_k] representing the
+        top-k expert indices for each token.
+    - block_size: The block size used in block matrix multiplication.
+    - num_experts: The total number of experts.
+
+    Returns:
+    - sorted_token_ids: A tensor containing the sorted token indices according
+        to their allocated expert.
+    - expert_ids: A tensor indicating the assigned expert index for each block.
+    - num_tokens_post_padded: The total number of tokens after padding,
+        ensuring divisibility by block_size.
+
+    This function pads the number of tokens that each expert needs to process
+    so that it is divisible by block_size.
+    Padding ensures that during block matrix multiplication, the dimensions
+    align correctly.
+
+    Example:
+    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
+    block_size = 4, and num_experts = 4:
+    - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
+        with each expert needing to process 3 tokens.
+    - As block_size is 4, we pad 1 token for each expert.
+    - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
+    - Then append padding tokens [12, 12, 12, 12] for each block.
+    - After sorting by expert index, we obtain token_ids
+        [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
+        Tokens 12 are non-existent (padding) and are ignored in
+        the subsequent matrix multiplication.
+    - The padding ensures that the total number of tokens is now divisible
+        by block_size for proper block matrix operations.
+    """
+    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
+    sorted_ids = torch.empty((max_num_tokens_padded, ),
+                             dtype=torch.int32,
+                             device=topk_ids.device)
+    sorted_ids.fill_(topk_ids.numel())
+    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
+    expert_ids = torch.empty((max_num_m_blocks, ),
+                             dtype=torch.int32,
+                             device=topk_ids.device)
+    num_tokens_post_pad = torch.empty((1),
+                                      dtype=torch.int32,
+                                      device=topk_ids.device)
+    ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
+                             expert_ids, num_tokens_post_pad)
+    return sorted_ids, expert_ids, num_tokens_post_pad
+
+
+def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
+                            A_scale: Optional[torch.Tensor],
+                            B_scale: Optional[torch.Tensor],
+                            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
+                            sorted_token_ids: torch.Tensor,
+                            expert_ids: torch.Tensor,
+                            num_tokens_post_padded: torch.Tensor,
+                            mul_routed_weight: bool, top_k: int,
+                            config: Dict[str, Any], compute_type: tl.dtype,
+                            use_fp8: bool) -> None:
+    assert topk_weights.stride(1) == 1
+    assert sorted_token_ids.stride(0) == 1
+
+    if not use_fp8:
+        assert A_scale is None
+        assert B_scale is None
+    else:
+        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
+        assert B_scale is not None
+
+    grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
+        'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
+
+    fused_moe_kernel[grid](
+        A,
+        B,
+        C,
+        A_scale,
+        B_scale,
+        topk_weights,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        B.shape[1],
+        B.shape[2],
+        sorted_token_ids.shape[0],
+        topk_ids.numel(),
+        A.stride(0),
+        A.stride(1),
+        B.stride(0),
+        B.stride(2),
+        B.stride(1),
+        C.stride(1),
+        C.stride(2),
+        MUL_ROUTED_WEIGHT=mul_routed_weight,
+        top_k=top_k,
+        compute_type=compute_type,
+        use_fp8=use_fp8,
+        **config,
+    )
+
+
+def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
+    device_name = torch.cuda.get_device_name().replace(" ", "_")
+    dtype_selector = "" if not dtype else f",dtype={dtype}"
+    return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
+
+
+@functools.lru_cache
+def get_moe_configs(E: int, N: int,
+                    dtype: Optional[str]) -> Optional[Dict[int, Any]]:
+    """
+    Return optimized configurations for the fused MoE kernel.
+
+    The return value will be a dictionary that maps an irregular grid of
+    batch sizes to configurations of the fused_moe kernel. To evaluate the
+    kernel on a given batch size bs, the closest batch size in the grid should
+    be picked and the associated configuration chosen to invoke the kernel.
+    """
+
+    # First look up if an optimized configuration is available in the configs
+    # directory
+    json_file_name = get_config_file_name(E, N, dtype)
+
+    config_file_path = os.path.join(
+        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
+    if os.path.exists(config_file_path):
+        with open(config_file_path) as f:
+            logger.info("Using configuration from %s for MoE layer.",
+                        config_file_path)
+            # If a configuration has been found, return it
+            return {int(key): val for key, val in json.load(f).items()}
+
+    # If no optimized configuration is available, we will use the default
+    # configuration
+    return None
+
+
+def fused_moe(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    inplace: bool = False,
+    override_config: Optional[Dict[str, Any]] = None,
+    use_fp8: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """
+    This function computes a Mixture of Experts (MoE) layer using two sets of
+    weights, w1 and w2, and top-k gating mechanism.
+
+    Parameters:
+    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
+    - w1 (torch.Tensor): The first set of expert weights.
+    - w2 (torch.Tensor): The second set of expert weights.
+    - gating_output (torch.Tensor): The output of the gating operation
+        (before softmax).
+    - topk (int): The number of top-k experts to select.
+    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
+    - inplace (bool): If True, perform the operation in-place.
+        Defaults to False.
+    - override_config (Optional[Dict[str, Any]]): Optional override
+        for the kernel configuration.
+    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
+        products for w1 and w2. Defaults to False.
+    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
+        w1.
+    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
+        w2.
+
+    Returns:
+    - torch.Tensor: The output tensor after applying the MoE layer.
+    """
+    # Check constraints.
+    assert hidden_states.shape[0] == gating_output.shape[0], (
+        "Number of tokens mismatch")
+    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
+    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
+    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
+    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
+    assert hidden_states.dtype in [
+        torch.float32, torch.float16, torch.bfloat16
+    ]
+    M, _ = hidden_states.shape
+    E, N, _ = w1.shape
+
+    if is_hip():
+        # The MoE kernels are not yet supported on ROCm.
+        routing_weights = torch.softmax(gating_output,
+                                        dim=-1,
+                                        dtype=torch.float32)
+        topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
+    else:
+        import vllm._moe_C as moe_kernels
+
+        topk_weights = torch.empty(M,
+                                   topk,
+                                   dtype=torch.float32,
+                                   device=hidden_states.device)
+        topk_ids = torch.empty(M,
+                               topk,
+                               dtype=torch.int32,
+                               device=hidden_states.device)
+        token_expert_indicies = torch.empty(M,
+                                            topk,
+                                            dtype=torch.int32,
+                                            device=hidden_states.device)
+        moe_kernels.topk_softmax(
+            topk_weights,
+            topk_ids,
+            token_expert_indicies,
+            gating_output.float(),  # TODO(woosuk): Optimize this.
+        )
+        del token_expert_indicies  # Not used. Will be used in the future.
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    if override_config:
+        config = override_config
+    else:
+        # First try to load optimal config from the file
+        configs = get_moe_configs(E, w2.shape[2],
+                                  "float8" if use_fp8 else None)
+
+        if configs:
+            # If an optimal configuration map has been found, look up the
+            # optimal config
+            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
+        else:
+            # Else use the default config
+            config = {
+                "BLOCK_SIZE_M": 128,
+                "BLOCK_SIZE_N": 64,
+                "BLOCK_SIZE_K": 128,
+                "GROUP_SIZE_M": 1,
+                "num_warps": 4,
+                "num_stages": 4
+            }
+
+            if M <= E:
+                config = {
+                    "BLOCK_SIZE_M": 128,
+                    "BLOCK_SIZE_N": 256,
+                    "BLOCK_SIZE_K": 128,
+                    "GROUP_SIZE_M": 16,
+                    "num_warps": 8,
+                    "num_stages": 4
+                }
+
+    intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
+                                      device=hidden_states.device,
+                                      dtype=hidden_states.dtype)
+    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
+                                      device=hidden_states.device,
+                                      dtype=hidden_states.dtype)
+    intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
+                                      device=hidden_states.device,
+                                      dtype=hidden_states.dtype)
+
+    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
+        topk_ids, config['BLOCK_SIZE_M'], E)
+    compute_type = (tl.bfloat16
+                    if hidden_states.dtype == torch.bfloat16 else tl.float16)
+
+    invoke_fused_moe_kernel(hidden_states,
+                            w1,
+                            intermediate_cache1,
+                            a1_scale,
+                            w1_scale,
+                            topk_weights,
+                            topk_ids,
+                            sorted_token_ids,
+                            expert_ids,
+                            num_tokens_post_padded,
+                            False,
+                            topk_ids.shape[1],
+                            config,
+                            compute_type=compute_type,
+                            use_fp8=use_fp8)
+
+    ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+
+    invoke_fused_moe_kernel(intermediate_cache2,
+                            w2,
+                            intermediate_cache3,
+                            a2_scale,
+                            w2_scale,
+                            topk_weights,
+                            topk_ids,
+                            sorted_token_ids,
+                            expert_ids,
+                            num_tokens_post_padded,
+                            True,
+                            1,
+                            config,
+                            compute_type=compute_type,
+                            use_fp8=use_fp8)
+
+    if inplace:
+        return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
+                         dim=1,
+                         out=hidden_states)
+    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
+                     dim=1)
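The new sglang/srt/layers/fused_moe.py above is adapted from vLLM's fused MoE kernel. For orientation, a shape-level usage sketch of the fused_moe entry point follows; the sizes are hypothetical, and it assumes a CUDA device with vLLM installed, since the routing and alignment steps call into vllm._moe_C and vllm._custom_ops.

```python
import torch
from sglang.srt.layers.fused_moe import fused_moe

# Hypothetical sizes: M tokens, hidden size K, E experts, intermediate size I, top-k routing.
M, K, E, I, topk = 4, 512, 8, 1024, 2

hidden_states = torch.randn(M, K, dtype=torch.float16, device="cuda")
w1 = torch.randn(E, 2 * I, K, dtype=torch.float16, device="cuda")  # gate+up projections stacked on dim 1
w2 = torch.randn(E, K, I, dtype=torch.float16, device="cuda")      # down projection back to hidden size
gating_output = torch.randn(M, E, dtype=torch.float32, device="cuda")  # router logits, pre-softmax

out = fused_moe(hidden_states, w1, w2, gating_output, topk=topk, renormalize=True)
print(out.shape)  # torch.Size([4, 512]): one combined output row per token
```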
sglang/srt/layers/logits_processor.py
@@ -5,7 +5,7 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
 )

-from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
+from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata


 class LogitsProcessor(nn.Module):
@@ -50,21 +50,22 @@ class LogitsProcessor(nn.Module):
         prefill_top_logprobs, decode_top_logprobs = [], []
         pt = 0
         # NOTE: the GPU-CPU overhead can be reduced
-        extend_seq_lens_cpu = input_metadata.extend_seq_lens.cpu().numpy()
-        for i in range(len(extend_seq_lens_cpu)):
-            if extend_seq_lens_cpu[i] == 0:
+        extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
+        for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
+            if extend_seq_len == 0:
                 prefill_top_logprobs.append([])
                 decode_top_logprobs.append([])
                 continue
             k = input_metadata.top_logprobs_nums[i]
-            t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k)
+            t = all_logprobs[pt : pt + extend_seq_len].topk(k)
             vs_cpu = t.values.tolist()
             ps_cpu = t.indices.tolist()
             prefill_top_logprobs.append(
                 [list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
             )
             decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
-            pt += extend_seq_lens_cpu[i]
+            pt += extend_seq_len
+
         return prefill_top_logprobs, decode_top_logprobs

     def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
@@ -145,7 +146,7 @@ class LogitsProcessor(nn.Module):
         )


-if __name__ == "__main__":
+def test():
     all_logprobs = torch.tensor(
         # s s s
         [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
@@ -173,3 +174,7 @@ if __name__ == "__main__":
     print("start", start)
     print("end", end)
     print("sum_logp", sum_logp)
+
+
+if __name__ == "__main__":
+    test()
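The logits_processor.py change switches the extend-length transfer to .tolist() and iterates the resulting Python list directly. The surrounding loop packs, for each request, the top-k logprobs of every prefill position plus the final (decode) position. Below is a standalone sketch of that packing for a single request, using made-up tensors purely for illustration.

```python
import torch

# One request with 3 extend positions over a toy vocabulary of 5 tokens.
all_logprobs = torch.log_softmax(torch.randn(3, 5), dim=-1)
k = 2

t = all_logprobs.topk(k)
vs_cpu, ps_cpu = t.values.tolist(), t.indices.tolist()
prefill_top_logprobs = [list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]  # all but last position
decode_top_logprobs = list(zip(vs_cpu[-1], ps_cpu[-1]))                                   # last position only
print(prefill_top_logprobs)
print(decode_top_logprobs)
```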
sglang/srt/layers/radix_attention.py
@@ -1,22 +1,26 @@
 import torch
+import numpy as np
 from torch import nn

 from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
+from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata


 class RadixAttention(nn.Module):
-    def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id):
+    def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1):
         super().__init__()
         self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
         self.tp_v_head_num = num_kv_heads
         self.head_dim = head_dim
         self.layer_id = layer_id
+        self.logit_cap = logit_cap

-        from sglang.srt.managers.router.model_runner import global_server_args_dict
+        assert np.allclose(scaling, 1.0 / (head_dim**0.5))
+
+        from sglang.srt.managers.controller.model_runner import global_server_args_dict

         if global_server_args_dict.get("enable_flashinfer", False):
             self.prefill_forward = self.prefill_forward_flashinfer
@@ -38,6 +42,7 @@ class RadixAttention(nn.Module):
             input_metadata.start_loc,
             input_metadata.seq_lens,
             input_metadata.max_seq_len,
+            self.logit_cap,
         )
         self.store_kv_cache(k, v, input_metadata)

@@ -62,6 +67,7 @@ class RadixAttention(nn.Module):
             input_metadata.extend_seq_lens,
             input_metadata.max_seq_len,
             input_metadata.max_extend_len,
+            self.logit_cap,
         )

         return o
@@ -82,6 +88,7 @@ class RadixAttention(nn.Module):
             input_metadata.max_seq_len,
             input_metadata.other_kv_index,
             input_metadata.total_num_tokens,
+            self.logit_cap,
         )

         return o
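radix_attention.py now accepts an optional logit_cap (default -1, i.e. disabled), forwards it to the prefill, extend, and decode kernels, and asserts that scaling equals 1/sqrt(head_dim). A hypothetical construction inside a model's attention module might look like the sketch below; the head counts and cap value are illustrative only, and the layer is normally instantiated inside a running sglang model runner, where global_server_args_dict is populated.

```python
from sglang.srt.layers.radix_attention import RadixAttention

head_dim = 128
attn = RadixAttention(
    num_heads=32,
    head_dim=head_dim,
    scaling=head_dim**-0.5,  # must match 1 / sqrt(head_dim) per the new assert
    num_kv_heads=8,
    layer_id=0,
    logit_cap=30.0,          # illustrative cap; <= 0 keeps capping disabled
)
```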