sglang-0.1.18-py3-none-any.whl → sglang-0.1.19-py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (38)
  1. sglang/__init__.py +1 -1
  2. sglang/api.py +26 -0
  3. sglang/backend/runtime_endpoint.py +18 -14
  4. sglang/bench_latency.py +34 -16
  5. sglang/global_config.py +1 -0
  6. sglang/lang/chat_template.py +41 -6
  7. sglang/lang/interpreter.py +5 -1
  8. sglang/lang/ir.py +61 -25
  9. sglang/srt/constrained/__init__.py +3 -2
  10. sglang/srt/hf_transformers_utils.py +7 -3
  11. sglang/srt/layers/extend_attention.py +2 -1
  12. sglang/srt/layers/fused_moe.py +181 -167
  13. sglang/srt/layers/logits_processor.py +55 -19
  14. sglang/srt/layers/radix_attention.py +24 -27
  15. sglang/srt/layers/token_attention.py +4 -1
  16. sglang/srt/managers/controller/infer_batch.py +2 -2
  17. sglang/srt/managers/controller/manager_single.py +1 -1
  18. sglang/srt/managers/controller/model_runner.py +27 -15
  19. sglang/srt/managers/controller/tp_worker.py +31 -14
  20. sglang/srt/managers/detokenizer_manager.py +4 -2
  21. sglang/srt/managers/io_struct.py +1 -1
  22. sglang/srt/managers/tokenizer_manager.py +14 -13
  23. sglang/srt/model_config.py +6 -0
  24. sglang/srt/models/gemma2.py +436 -0
  25. sglang/srt/models/llama2.py +3 -3
  26. sglang/srt/models/llama_classification.py +10 -7
  27. sglang/srt/models/minicpm.py +373 -0
  28. sglang/srt/models/qwen2_moe.py +454 -0
  29. sglang/srt/openai_api_adapter.py +2 -2
  30. sglang/srt/openai_protocol.py +1 -1
  31. sglang/srt/server.py +17 -8
  32. sglang/srt/server_args.py +14 -16
  33. sglang/srt/utils.py +68 -35
  34. {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/METADATA +19 -13
  35. {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/RECORD +38 -35
  36. {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/LICENSE +0 -0
  37. {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/WHEEL +0 -0
  38. {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/top_level.txt +0 -0
sglang/srt/layers/fused_moe.py
@@ -9,7 +9,6 @@ from typing import Any, Dict, Optional, Tuple
 import torch
 import triton
 import triton.language as tl
-
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 
@@ -108,12 +107,16 @@ def fused_moe_kernel(
 
     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
     offs_k = tl.arange(0, BLOCK_SIZE_K)
-    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
-                      offs_k[None, :] * stride_ak)
+    a_ptrs = a_ptr + (
+        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
+    )
 
     off_experts = tl.load(expert_ids_ptr + pid_m)
-    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
-                                                offs_bn[None, :] * stride_bn)
+    b_ptrs = (
+        b_ptr
+        + off_experts * stride_be
+        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    )
 
     if use_fp8:
         a_scale = tl.load(a_scale_ptr)
@@ -129,13 +132,12 @@ def fused_moe_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the
        # K dimension.
-        a = tl.load(a_ptrs,
-                    mask=token_mask[:, None] &
-                    (offs_k[None, :] < K - k * BLOCK_SIZE_K),
-                    other=0.0)
-        b = tl.load(b_ptrs,
-                    mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
-                    other=0.0)
+        a = tl.load(
+            a_ptrs,
+            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+            other=0.0,
+        )
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
         # We accumulate along the K dimension.
         if use_fp8:
             accumulator = tl.dot(a, b, acc=accumulator)
@@ -146,9 +148,7 @@ def fused_moe_kernel(
         b_ptrs += BLOCK_SIZE_K * stride_bk
 
     if MUL_ROUTED_WEIGHT:
-        moe_weight = tl.load(topk_weights_ptr + offs_token,
-                             mask=token_mask,
-                             other=0)
+        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
         accumulator = accumulator * moe_weight[:, None]
 
     if use_fp8:
@@ -158,15 +158,14 @@ def fused_moe_kernel(
     # -----------------------------------------------------------
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
-        None, :]
+    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
     tl.store(c_ptrs, accumulator, mask=c_mask)
 
 
 def moe_align_block_size(
-        topk_ids: torch.Tensor, block_size: int,
-        num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    topk_ids: torch.Tensor, block_size: int, num_experts: int
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Aligns the token distribution across experts to be compatible with block
     size for matrix multiplication.
@@ -205,32 +204,38 @@ def moe_align_block_size(
         by block_size for proper block matrix operations.
     """
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids = torch.empty((max_num_tokens_padded, ),
-                             dtype=torch.int32,
-                             device=topk_ids.device)
+    sorted_ids = torch.empty(
+        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
+    )
     sorted_ids.fill_(topk_ids.numel())
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
-    expert_ids = torch.empty((max_num_m_blocks, ),
-                             dtype=torch.int32,
-                             device=topk_ids.device)
-    num_tokens_post_pad = torch.empty((1),
-                                      dtype=torch.int32,
-                                      device=topk_ids.device)
-    ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
-                             expert_ids, num_tokens_post_pad)
+    expert_ids = torch.empty(
+        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
+    )
+    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
+    ops.moe_align_block_size(
+        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
+    )
     return sorted_ids, expert_ids, num_tokens_post_pad
 
 
-def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
-                            A_scale: Optional[torch.Tensor],
-                            B_scale: Optional[torch.Tensor],
-                            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
-                            sorted_token_ids: torch.Tensor,
-                            expert_ids: torch.Tensor,
-                            num_tokens_post_padded: torch.Tensor,
-                            mul_routed_weight: bool, top_k: int,
-                            config: Dict[str, Any], compute_type: tl.dtype,
-                            use_fp8: bool) -> None:
+def invoke_fused_moe_kernel(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    C: torch.Tensor,
+    A_scale: Optional[torch.Tensor],
+    B_scale: Optional[torch.Tensor],
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    sorted_token_ids: torch.Tensor,
+    expert_ids: torch.Tensor,
+    num_tokens_post_padded: torch.Tensor,
+    mul_routed_weight: bool,
+    top_k: int,
+    config: Dict[str, Any],
+    compute_type: tl.dtype,
+    use_fp8: bool,
+) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
 
@@ -241,8 +246,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
         A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         assert B_scale is not None
 
-    grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
-        'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
+    grid = lambda META: (
+        triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
+        * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
+    )
 
     fused_moe_kernel[grid](
         A,
@@ -280,8 +287,7 @@ def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
 
 
 @functools.lru_cache
-def get_moe_configs(E: int, N: int,
-                    dtype: Optional[str]) -> Optional[Dict[int, Any]]:
+def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
     """
     Return optimized configurations for the fused MoE kernel.
 
@@ -296,11 +302,11 @@ def get_moe_configs(E: int, N: int,
     json_file_name = get_config_file_name(E, N, dtype)
 
     config_file_path = os.path.join(
-        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
+        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
+    )
     if os.path.exists(config_file_path):
         with open(config_file_path) as f:
-            logger.info("Using configuration from %s for MoE layer.",
-                        config_file_path)
+            logger.info("Using configuration from %s for MoE layer.", config_file_path)
             # If a configuration has been found, return it
             return {int(key): val for key, val in json.load(f).items()}
 
@@ -319,35 +325,35 @@ def get_default_config(
 ) -> Dict[str, int]:
     if dtype == "float8":
         config = {
-            'BLOCK_SIZE_M': 128,
-            'BLOCK_SIZE_N': 256,
-            'BLOCK_SIZE_K': 128,
-            'GROUP_SIZE_M': 32,
+            "BLOCK_SIZE_M": 128,
+            "BLOCK_SIZE_N": 256,
+            "BLOCK_SIZE_K": 128,
+            "GROUP_SIZE_M": 32,
             "num_warps": 8,
-            "num_stages": 4
+            "num_stages": 4,
         }
         if M <= E:
             config = {
-                'BLOCK_SIZE_M': 64,
-                'BLOCK_SIZE_N': 128,
-                'BLOCK_SIZE_K': 128,
-                'GROUP_SIZE_M': 1,
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": 128,
+                "BLOCK_SIZE_K": 128,
+                "GROUP_SIZE_M": 1,
                 "num_warps": 4,
-                "num_stages": 4
+                "num_stages": 4,
             }
     else:
         config = {
-            'BLOCK_SIZE_M': 64,
-            'BLOCK_SIZE_N': 64,
-            'BLOCK_SIZE_K': 32,
-            'GROUP_SIZE_M': 8
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": 64,
+            "BLOCK_SIZE_K": 32,
+            "GROUP_SIZE_M": 8,
         }
         if M <= E:
             config = {
-                'BLOCK_SIZE_M': 16,
-                'BLOCK_SIZE_N': 32,
-                'BLOCK_SIZE_K': 64,
-                'GROUP_SIZE_M': 1
+                "BLOCK_SIZE_M": 16,
+                "BLOCK_SIZE_N": 32,
+                "BLOCK_SIZE_K": 64,
+                "GROUP_SIZE_M": 1,
             }
     return config
 
@@ -358,23 +364,17 @@ def fused_topk(
     topk: int,
     renormalize: bool,
 ):
-    assert hidden_states.shape[0] == gating_output.shape[0], (
-        "Number of tokens mismatch")
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
     M, _ = hidden_states.shape
 
-    topk_weights = torch.empty(M,
-                               topk,
-                               dtype=torch.float32,
-                               device=hidden_states.device)
-    topk_ids = torch.empty(M,
-                           topk,
-                           dtype=torch.int32,
-                           device=hidden_states.device)
-    token_expert_indicies = torch.empty(M,
-                                        topk,
-                                        dtype=torch.int32,
-                                        device=hidden_states.device)
+    topk_weights = torch.empty(
+        M, topk, dtype=torch.float32, device=hidden_states.device
+    )
+    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+    token_expert_indicies = torch.empty(
+        M, topk, dtype=torch.int32, device=hidden_states.device
+    )
     ops.topk_softmax(
         topk_weights,
         topk_ids,
@@ -388,27 +388,27 @@ def fused_topk(
     return topk_weights, topk_ids
 
 
-def fused_experts(hidden_states: torch.Tensor,
-                  w1: torch.Tensor,
-                  w2: torch.Tensor,
-                  topk_weights: torch.Tensor,
-                  topk_ids: torch.Tensor,
-                  inplace: bool = False,
-                  override_config: Optional[Dict[str, Any]] = None,
-                  use_fp8: bool = False,
-                  w1_scale: Optional[torch.Tensor] = None,
-                  w2_scale: Optional[torch.Tensor] = None,
-                  a1_scale: Optional[torch.Tensor] = None,
-                  a2_scale: Optional[torch.Tensor] = None):
+def fused_experts(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    inplace: bool = False,
+    override_config: Optional[Dict[str, Any]] = None,
+    use_fp8: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+):
     # Check constraints.
     assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
     assert w2.is_contiguous(), "Expert weights2 must be contiguous"
-    assert hidden_states.dtype in [
-        torch.float32, torch.float16, torch.bfloat16
-    ]
+    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
 
     M, _ = hidden_states.shape
     E, N, _ = w1.shape
@@ -417,8 +417,7 @@ def fused_experts(hidden_states: torch.Tensor,
         config = override_config
     else:
         # First try to load optimal config from the file
-        configs = get_moe_configs(E, w2.shape[2],
-                                  "float8" if use_fp8 else None)
+        configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
 
         if configs:
             # If an optimal configuration map has been found, look up the
@@ -426,65 +425,76 @@ def fused_experts(hidden_states: torch.Tensor,
             config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
         else:
             # Else use the default config
-            config = get_default_config(M, E, N, w1.shape[2],
-                                        topk_ids.shape[1],
-                                        "float8" if use_fp8 else None)
-
-    intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
-    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
-    intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
+            config = get_default_config(
+                M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None
+            )
+
+    intermediate_cache1 = torch.empty(
+        (M, topk_ids.shape[1], N),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+    intermediate_cache2 = torch.empty(
+        (M * topk_ids.shape[1], N // 2),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+    intermediate_cache3 = torch.empty(
+        (M, topk_ids.shape[1], w2.shape[1]),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
 
     sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        topk_ids, config['BLOCK_SIZE_M'], E)
-    compute_type = (tl.bfloat16
-                    if hidden_states.dtype == torch.bfloat16 else tl.float16)
-
-    invoke_fused_moe_kernel(hidden_states,
-                            w1,
-                            intermediate_cache1,
-                            a1_scale,
-                            w1_scale,
-                            topk_weights,
-                            topk_ids,
-                            sorted_token_ids,
-                            expert_ids,
-                            num_tokens_post_padded,
-                            False,
-                            topk_ids.shape[1],
-                            config,
-                            compute_type=compute_type,
-                            use_fp8=use_fp8)
+        topk_ids, config["BLOCK_SIZE_M"], E
+    )
+    compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
+
+    invoke_fused_moe_kernel(
+        hidden_states,
+        w1,
+        intermediate_cache1,
+        a1_scale,
+        w1_scale,
+        topk_weights,
+        topk_ids,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        False,
+        topk_ids.shape[1],
+        config,
+        compute_type=compute_type,
+        use_fp8=use_fp8,
+    )
 
     ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
 
-    invoke_fused_moe_kernel(intermediate_cache2,
-                            w2,
-                            intermediate_cache3,
-                            a2_scale,
-                            w2_scale,
-                            topk_weights,
-                            topk_ids,
-                            sorted_token_ids,
-                            expert_ids,
-                            num_tokens_post_padded,
-                            True,
-                            1,
-                            config,
-                            compute_type=compute_type,
-                            use_fp8=use_fp8)
+    invoke_fused_moe_kernel(
+        intermediate_cache2,
+        w2,
+        intermediate_cache3,
+        a2_scale,
+        w2_scale,
+        topk_weights,
+        topk_ids,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        True,
+        1,
+        config,
+        compute_type=compute_type,
+        use_fp8=use_fp8,
+    )
 
     if inplace:
-        return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                         dim=1,
-                         out=hidden_states)
-    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                     dim=1)
+        return torch.sum(
+            intermediate_cache3.view(*intermediate_cache3.shape),
+            dim=1,
+            out=hidden_states,
+        )
+    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
 
 
 def fused_moe(
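Note: the tuned MoE configurations returned by get_moe_configs are keyed by the token count M, and fused_experts picks the entry whose key is nearest to the current M via `configs[min(configs.keys(), key=lambda x: abs(x - M))]`. A minimal standalone sketch of that lookup; the keys and block sizes below are invented for illustration, not taken from any shipped config file:

```python
# Hypothetical per-M entries, mimicking the shape of a MoE config JSON file.
configs = {
    1: {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32},
    16: {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64},
    64: {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64},
    1024: {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128},
}

M = 48  # number of tokens in the current batch
# Same nearest-key lookup as in fused_experts above.
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
print(config)  # the key closest to 48 is 64, so its entry is chosen
```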
@@ -532,25 +542,28 @@ def fused_moe(
     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
 
     if hasattr(ops, "topk_softmax"):
-        topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
-                                            renormalize)
+        topk_weights, topk_ids = fused_topk(
+            hidden_states, gating_output, topk, renormalize
+        )
     else:
-        topk_weights, topk_ids = fused_topk_v0_4_3(hidden_states, gating_output, topk,
-                                                   renormalize)
-
-    return fused_experts(hidden_states,
-                         w1,
-                         w2,
-                         topk_weights,
-                         topk_ids,
-                         inplace=inplace,
-                         override_config=override_config,
-                         use_fp8=use_fp8,
-                         w1_scale=w1_scale,
-                         w2_scale=w2_scale,
-                         a1_scale=a1_scale,
-                         a2_scale=a2_scale)
-
+        topk_weights, topk_ids = fused_topk_v0_4_3(
+            hidden_states, gating_output, topk, renormalize
+        )
+
+    return fused_experts(
+        hidden_states,
+        w1,
+        w2,
+        topk_weights,
+        topk_ids,
+        inplace=inplace,
+        override_config=override_config,
+        use_fp8=use_fp8,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+    )
 
 
 def fused_topk_v0_4_3(
@@ -560,6 +573,7 @@ def fused_topk_v0_4_3(
     renormalize: bool,
 ):
     import vllm._moe_C as moe_kernels
+
     M, _ = hidden_states.shape
 
     topk_weights = torch.empty(
@@ -579,4 +593,4 @@ def fused_topk_v0_4_3(
     if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
 
-    return topk_weights, topk_ids
+    return topk_weights, topk_ids
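For reference, the routing step that fused_topk / fused_topk_v0_4_3 perform with fused CUDA ops (softmax over the gating logits, top-k selection, optional renormalization of the kept weights) can be approximated in plain PyTorch. This is an illustrative, unfused sketch only; it omits the token_expert_indicies buffer and makes no claim about matching the fused kernels exactly:

```python
import torch


def naive_topk_routing(gating_output: torch.Tensor, topk: int, renormalize: bool):
    # Softmax over the expert logits, then keep the top-k experts per token.
    scores = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        # Same renormalization as above: the kept weights sum to 1 per token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)


gating_output = torch.randn(4, 8)  # 4 tokens, 8 experts
weights, ids = naive_topk_routing(gating_output, topk=2, renormalize=True)
```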
sglang/srt/layers/logits_processor.py
@@ -1,7 +1,7 @@
 """Logits processing."""
 
 import dataclasses
-from typing import List
+from typing import List, Union
 
 import torch
 from torch import nn
@@ -31,6 +31,27 @@ class LogitProcessorOutput:
     decode_top_logprobs: List
 
 
+@dataclasses.dataclass
+class LogitsMetadata:
+    forward_mode: ForwardMode
+    extend_seq_lens: torch.Tensor
+    extend_start_loc: torch.Tensor
+
+    # For logprobs
+    return_logprob: bool
+    top_logprobs_nums: List[int]
+
+    @classmethod
+    def from_input_metadata(cls, input_metadata: InputMetadata):
+        return cls(
+            forward_mode=input_metadata.forward_mode,
+            extend_seq_lens=input_metadata.extend_seq_lens,
+            extend_start_loc=input_metadata.extend_start_loc,
+            return_logprob=input_metadata.return_logprob,
+            top_logprobs_nums=input_metadata.top_logprobs_nums,
+        )
+
+
 class LogitsProcessor(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -38,14 +59,14 @@ class LogitsProcessor(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
 
     def _get_normalized_prompt_logprobs(
-        self, prefill_token_logprobs, input_metadata: InputMetadata
+        self, prefill_token_logprobs, logits_metadata: LogitsMetadata
     ):
         logprobs_cumsum = torch.cumsum(
             prefill_token_logprobs, dim=0, dtype=torch.float32
         )
 
-        start = input_metadata.extend_start_loc.clone()
-        end = start + input_metadata.extend_seq_lens - 2
+        start = logits_metadata.extend_start_loc.clone()
+        end = start + logits_metadata.extend_seq_lens - 2
         start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
         end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
         sum_logp = (
@@ -54,17 +75,17 @@ class LogitsProcessor(nn.Module):
             + prefill_token_logprobs[start]
         )
         normalized_prompt_logprobs = sum_logp / (
-            (input_metadata.extend_seq_lens - 1).clamp(min=1)
+            (logits_metadata.extend_seq_lens - 1).clamp(min=1)
         )
 
         return normalized_prompt_logprobs
 
-    def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata):
+    def _get_top_logprobs(self, all_logprobs, logits_metadata: LogitsMetadata):
         # TODO: vectorize the code below
-        if input_metadata.forward_mode == ForwardMode.DECODE:
+        if logits_metadata.forward_mode == ForwardMode.DECODE:
             decode_top_logprobs = []
             for i in range(all_logprobs.shape[0]):
-                k = input_metadata.top_logprobs_nums[i]
+                k = logits_metadata.top_logprobs_nums[i]
                 t = all_logprobs[i].topk(k)
                 v_cpu = t.values.tolist()
                 p_cpu = t.indices.tolist()
@@ -73,13 +94,13 @@ class LogitsProcessor(nn.Module):
         else:
             prefill_top_logprobs, decode_top_logprobs = [], []
             pt = 0
-            extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
+            extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
             for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
                 if extend_seq_len == 0:
                     prefill_top_logprobs.append([])
                     decode_top_logprobs.append([])
                     continue
-                k = input_metadata.top_logprobs_nums[i]
+                k = logits_metadata.top_logprobs_nums[i]
                 t = all_logprobs[pt : pt + extend_seq_len].topk(k)
                 vs_cpu = t.values.tolist()
                 ps_cpu = t.indices.tolist()
@@ -91,14 +112,24 @@ class LogitsProcessor(nn.Module):
 
         return prefill_top_logprobs, decode_top_logprobs
 
-    def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
+    def forward(
+        self,
+        input_ids,
+        hidden_states,
+        weight,
+        logits_metadata: Union[LogitsMetadata, InputMetadata],
+    ):
+        if isinstance(logits_metadata, InputMetadata):
+            logits_metadata = LogitsMetadata.from_input_metadata(logits_metadata)
+        assert isinstance(logits_metadata, LogitsMetadata)
+
         # Get the last hidden states and last logits for the next token prediction
-        if input_metadata.forward_mode == ForwardMode.DECODE:
+        if logits_metadata.forward_mode == ForwardMode.DECODE:
             last_index = None
             last_hidden = hidden_states
         else:
             last_index = (
-                torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long)
+                torch.cumsum(logits_metadata.extend_seq_lens, dim=0, dtype=torch.long)
                 - 1
             )
             last_hidden = hidden_states[last_index]
@@ -108,8 +139,13 @@ class LogitsProcessor(nn.Module):
             last_logits = tensor_model_parallel_all_gather(last_logits)
         last_logits = last_logits[:, : self.config.vocab_size]
 
+        if hasattr(self.config, "final_logit_softcapping"):
+            last_logits /= self.config.final_logit_softcapping
+            last_logits = torch.tanh(last_logits)
+            last_logits *= self.config.final_logit_softcapping
+
         # Return only last_logits if logprob is not requested
-        if not input_metadata.return_logprob:
+        if not logits_metadata.return_logprob:
             return LogitProcessorOutput(
                 next_token_logits=last_logits,
                 next_token_logprobs=None,
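The final_logit_softcapping branch above (defined by some model configs, e.g. Gemma 2, whose support is added in this release) bounds the logits to (-cap, cap) with a scaled tanh. A minimal standalone sketch of the same divide/tanh/multiply sequence, not taken from the package:

```python
import torch


def soft_cap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Equivalent to the in-place sequence above: divide by the cap, tanh, rescale.
    return cap * torch.tanh(logits / cap)


logits = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
print(soft_cap(logits, cap=30.0))  # all values now lie strictly within (-30, 30)
```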
@@ -120,7 +156,7 @@ class LogitsProcessor(nn.Module):
             )
         else:
             # When logprob is requested, compute the logits for all tokens.
-            if input_metadata.forward_mode == ForwardMode.DECODE:
+            if logits_metadata.forward_mode == ForwardMode.DECODE:
                 all_logits = last_logits
             else:
                 all_logits = torch.matmul(hidden_states, weight.T)
@@ -133,15 +169,15 @@ class LogitsProcessor(nn.Module):
             all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
 
             # Get the logprob of top-k tokens
-            return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
+            return_top_logprob = any(x > 0 for x in logits_metadata.top_logprobs_nums)
             if return_top_logprob:
                 prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
-                    all_logprobs, input_metadata
+                    all_logprobs, logits_metadata
                 )
             else:
                 prefill_top_logprobs = decode_top_logprobs = None
 
-            if input_metadata.forward_mode == ForwardMode.DECODE:
+            if logits_metadata.forward_mode == ForwardMode.DECODE:
                 return LogitProcessorOutput(
                     next_token_logits=last_logits,
                     next_token_logprobs=all_logprobs,
@@ -161,7 +197,7 @@ class LogitsProcessor(nn.Module):
             ]
 
             normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
-                prefill_token_logprobs, input_metadata
+                prefill_token_logprobs, logits_metadata
            )
 
             return LogitProcessorOutput(