sglang 0.3.0__py3-none-any.whl → 0.3.1.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/bench_latency.py +17 -8
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +5 -17
  4. sglang/lang/backend/runtime_endpoint.py +5 -2
  5. sglang/lang/interpreter.py +1 -4
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +33 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +1 -3
  14. sglang/srt/layers/activation.py +12 -0
  15. sglang/srt/layers/attention_backend.py +480 -0
  16. sglang/srt/layers/flashinfer_utils.py +235 -0
  17. sglang/srt/layers/fused_moe/layer.py +27 -7
  18. sglang/srt/layers/layernorm.py +12 -0
  19. sglang/srt/layers/logits_processor.py +64 -77
  20. sglang/srt/layers/radix_attention.py +11 -161
  21. sglang/srt/layers/sampler.py +38 -122
  22. sglang/srt/layers/torchao_utils.py +75 -0
  23. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  24. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  25. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  26. sglang/srt/lora/lora.py +403 -0
  27. sglang/srt/lora/lora_config.py +43 -0
  28. sglang/srt/lora/lora_manager.py +259 -0
  29. sglang/srt/managers/controller_multi.py +1 -5
  30. sglang/srt/managers/controller_single.py +0 -5
  31. sglang/srt/managers/io_struct.py +16 -1
  32. sglang/srt/managers/policy_scheduler.py +122 -5
  33. sglang/srt/managers/schedule_batch.py +105 -71
  34. sglang/srt/managers/tokenizer_manager.py +17 -8
  35. sglang/srt/managers/tp_worker.py +188 -121
  36. sglang/srt/model_executor/cuda_graph_runner.py +69 -133
  37. sglang/srt/model_executor/forward_batch_info.py +35 -312
  38. sglang/srt/model_executor/model_runner.py +123 -154
  39. sglang/srt/models/baichuan.py +416 -0
  40. sglang/srt/models/chatglm.py +1 -5
  41. sglang/srt/models/commandr.py +1 -5
  42. sglang/srt/models/dbrx.py +1 -5
  43. sglang/srt/models/deepseek.py +1 -5
  44. sglang/srt/models/deepseek_v2.py +7 -6
  45. sglang/srt/models/exaone.py +1 -5
  46. sglang/srt/models/gemma.py +1 -5
  47. sglang/srt/models/gemma2.py +1 -5
  48. sglang/srt/models/gpt_bigcode.py +1 -5
  49. sglang/srt/models/grok.py +1 -5
  50. sglang/srt/models/internlm2.py +1 -5
  51. sglang/srt/models/llama.py +51 -5
  52. sglang/srt/models/llama_classification.py +1 -20
  53. sglang/srt/models/llava.py +30 -5
  54. sglang/srt/models/llavavid.py +2 -2
  55. sglang/srt/models/minicpm.py +1 -5
  56. sglang/srt/models/minicpm3.py +669 -0
  57. sglang/srt/models/mixtral.py +6 -5
  58. sglang/srt/models/mixtral_quant.py +1 -5
  59. sglang/srt/models/olmoe.py +415 -0
  60. sglang/srt/models/qwen.py +1 -5
  61. sglang/srt/models/qwen2.py +1 -5
  62. sglang/srt/models/qwen2_moe.py +6 -5
  63. sglang/srt/models/stablelm.py +1 -5
  64. sglang/srt/models/xverse.py +375 -0
  65. sglang/srt/models/xverse_moe.py +445 -0
  66. sglang/srt/openai_api/adapter.py +65 -46
  67. sglang/srt/openai_api/protocol.py +11 -3
  68. sglang/srt/sampling/sampling_batch_info.py +46 -80
  69. sglang/srt/server.py +30 -15
  70. sglang/srt/server_args.py +163 -28
  71. sglang/srt/utils.py +19 -51
  72. sglang/test/few_shot_gsm8k.py +132 -0
  73. sglang/test/runners.py +114 -22
  74. sglang/test/test_programs.py +7 -5
  75. sglang/test/test_utils.py +85 -2
  76. sglang/utils.py +32 -37
  77. sglang/version.py +1 -1
  78. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/METADATA +30 -18
  79. sglang-0.3.1.post1.dist-info/RECORD +130 -0
  80. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/WHEEL +1 -1
  81. sglang-0.3.0.dist-info/RECORD +0 -118
  82. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/LICENSE +0 -0
  83. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/flashinfer_utils.py (new file)
@@ -0,0 +1,235 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+
+ @triton.jit
+ def create_flashinfer_kv_indices_triton(
+     req_to_token_ptr,  # [max_batch, max_context_len]
+     req_pool_indices_ptr,
+     page_kernel_lens_ptr,
+     kv_indptr,
+     kv_start_idx,
+     kv_indices_ptr,
+     max_context_len: tl.constexpr,
+ ):
+     BLOCK_SIZE: tl.constexpr = 512
+     pid = tl.program_id(axis=0)
+     req_pool_index = tl.load(req_pool_indices_ptr + pid)
+     kv_indices_offset = tl.load(kv_indptr + pid)
+
+     kv_start = 0
+     kv_end = 0
+     if kv_start_idx:
+         kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
+         kv_end = kv_start
+     kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
+
+     req_to_token_ptr += req_pool_index * max_context_len
+     kv_indices_ptr += kv_indices_offset
+
+     ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
+     st_offset = tl.arange(0, BLOCK_SIZE)
+     num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+     for _ in range(num_loop):
+         mask = ld_offset < kv_end
+         data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
+         tl.store(kv_indices_ptr + st_offset, data, mask=mask)
+         ld_offset += BLOCK_SIZE
+         st_offset += BLOCK_SIZE
+
+
+ class FlashinferUpdater:
+     def __init__(
+         self,
+         forward_mode,
+         model_runner,
+         req_pool_indices,
+         seq_lens,
+         prefix_lens,
+         decode_wrapper=None,
+         use_ragged=False,
+     ):
+         self.forward_mode = forward_mode
+         self.model_runner = model_runner
+         self.req_pool_indices = req_pool_indices
+         self.seq_lens = seq_lens
+         self.prefix_lens = prefix_lens
+         self.use_ragged = use_ragged
+
+         self.num_qo_heads = (
+             model_runner.model_config.num_attention_heads // model_runner.tp_size
+         )
+         self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
+             model_runner.tp_size
+         )
+         self.head_dim = model_runner.model_config.head_dim
+         self.batch_size = len(req_pool_indices)
+
+         self.decode_wrapper = (
+             decode_wrapper or self.model_runner.attn_backend.decode_wrapper
+         )
+         self.prefill_wrapper_ragged = (
+             self.model_runner.attn_backend.prefill_wrapper_ragged
+         )
+         self.prefill_wrapper_paged = (
+             self.model_runner.attn_backend.prefill_wrapper_paged
+         )
+
+         self.kv_last_page_len = torch.ones(
+             (self.batch_size,), dtype=torch.int32, device="cuda"
+         )
+
+     def _init_indices_no_sliding_window(self):
+         if self.use_ragged:
+             paged_kernel_lens = self.prefix_lens
+         else:
+             paged_kernel_lens = self.seq_lens
+
+         self.kv_indptr = torch.zeros(
+             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
+         )
+         self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+         self.kv_indices = torch.empty(
+             self.kv_indptr[-1], dtype=torch.int32, device="cuda"
+         )
+
+         create_flashinfer_kv_indices_triton[(self.batch_size,)](
+             self.model_runner.req_to_token_pool.req_to_token,
+             self.req_pool_indices,
+             paged_kernel_lens,
+             self.kv_indptr,
+             None,
+             self.kv_indices,
+             self.model_runner.req_to_token_pool.req_to_token.size(1),
+         )
+
+     def _init_indices_sliding_window(self, wrapper_id):
+         if wrapper_id == 0:
+             # window attention use paged only
+             if self.forward_mode.is_decode():
+                 paged_kernel_lens = torch.minimum(
+                     self.seq_lens,
+                     torch.tensor(self.model_runner.sliding_window_size + 1),
+                 )
+             else:
+                 paged_kernel_lens = torch.minimum(
+                     self.seq_lens,
+                     torch.tensor(self.model_runner.sliding_window_size)
+                     + self.seq_lens
+                     - self.prefix_lens,
+                 )
+         else:
+             # full attention
+             paged_kernel_lens = self.seq_lens
+
+         kv_start_idx = self.seq_lens - paged_kernel_lens
+         self.kv_indptr = torch.zeros(
+             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
+         )
+         self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+         self.kv_indices = torch.empty(
+             self.kv_indptr[-1], dtype=torch.int32, device="cuda"
+         )
+         create_flashinfer_kv_indices_triton[(self.batch_size,)](
+             self.model_runner.req_to_token_pool.req_to_token,
+             self.req_pool_indices,
+             paged_kernel_lens,
+             self.kv_indptr,
+             kv_start_idx,
+             self.kv_indices,
+             self.model_runner.req_to_token_pool.req_to_token.size(1),
+         )
+
+     def _update_decode_indices(self, decode_wrapper):
+         decode_wrapper.end_forward()
+         decode_wrapper.begin_forward(
+             self.kv_indptr,
+             self.kv_indices,
+             self.kv_last_page_len,
+             self.num_qo_heads,
+             self.num_kv_heads,
+             self.head_dim,
+             1,
+             data_type=self.model_runner.kv_cache_dtype,
+             q_data_type=self.model_runner.dtype,
+         )
+
+     def _update_extend_indices(self, ragged_wrapper, paged_wrapper):
+         # extend part
+         qo_indptr = torch.zeros(
+             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
+         )
+         qo_indptr[1:] = torch.cumsum(self.seq_lens - self.prefix_lens, dim=0)
+
+         if self.use_ragged:
+             ragged_wrapper.end_forward()
+             ragged_wrapper.begin_forward(
+                 qo_indptr,
+                 qo_indptr,
+                 self.num_qo_heads,
+                 self.num_kv_heads,
+                 self.head_dim,
+             )
+
+         # cached part
+         paged_wrapper.end_forward()
+         paged_wrapper.begin_forward(
+             qo_indptr,
+             self.kv_indptr,
+             self.kv_indices,
+             self.kv_last_page_len,
+             self.num_qo_heads,
+             self.num_kv_heads,
+             self.head_dim,
+             1,
+         )
+
+     def update_indices_no_sliding_window(self):
+         self._init_indices_no_sliding_window()
+
+         if self.forward_mode.is_decode():
+             self._update_decode_indices(self.decode_wrapper)
+         else:
+             self._update_extend_indices(
+                 self.prefill_wrapper_ragged,
+                 self.prefill_wrapper_paged,
+             )
+
+     def update_indices_sliding_window(self):
+         assert self.use_ragged is False
+
+         for wrapper_id in range(2):
+             self._init_indices_sliding_window(wrapper_id)
+             if self.forward_mode.is_decode():
+                 self._update_decode_indices(self.decode_wrapper[wrapper_id])
+             else:
+                 self._update_extend_indices(
+                     None,
+                     self.prefill_wrapper_paged[wrapper_id],
+                 )
+
+
+ def update_flashinfer_indices(
+     forward_mode,
+     model_runner,
+     req_pool_indices,
+     seq_lens,
+     prefix_lens,
+     decode_wrapper=None,
+     use_ragged=False,
+ ):
+     updater = FlashinferUpdater(
+         forward_mode,
+         model_runner,
+         req_pool_indices,
+         seq_lens,
+         prefix_lens,
+         decode_wrapper,
+         use_ragged,
+     )
+
+     if model_runner.sliding_window_size is None:
+         updater.update_indices_no_sliding_window()
+     else:
+         updater.update_indices_sliding_window()
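Note on the new file above: the Triton kernel packs, for each request in the batch, the KV-cache token locations stored in the req_to_token table into one flat kv_indices array described by kv_indptr. A rough pure-PyTorch sketch of the same gather, for reading purposes only (the helper name and the Python loop are ours, not part of the wheel; the shipped kernel copies in 512-element blocks, one Triton program per request):

```python
import torch

def gather_kv_indices_reference(
    req_to_token,        # [max_batch, max_context_len] token locations per request slot
    req_pool_indices,    # [batch] which slot each request occupies
    paged_kernel_lens,   # [batch] how many KV entries to copy per request
    kv_indptr,           # [batch + 1] exclusive prefix sum of paged_kernel_lens
    kv_start_idx=None,   # [batch] optional per-request start offset (sliding window)
):
    kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)
    for i in range(len(req_pool_indices)):
        start = int(kv_start_idx[i]) if kv_start_idx is not None else 0
        length = int(paged_kernel_lens[i])
        row = req_to_token[int(req_pool_indices[i])]
        out = int(kv_indptr[i])
        kv_indices[out : out + length] = row[start : start + length]
    return kv_indices
```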
sglang/srt/layers/fused_moe/layer.py
@@ -18,6 +18,8 @@ from vllm.model_executor.layers.quantization.base_config import (
  from vllm.model_executor.layers.quantization.fp8 import Fp8Config
  from vllm.model_executor.utils import set_weight_attrs

+ from sglang.srt.utils import is_hip
+
  logger = init_logger(__name__)


@@ -381,6 +383,7 @@ from torch.nn import Module
  from vllm import _custom_ops as ops
  from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
      all_close_1d,
+     normalize_e4m3fn_to_e4m3fnuz,
      per_tensor_dequantize,
  )
  from vllm.utils import print_warning_once
@@ -479,14 +482,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):

      def process_weights_after_loading(self, layer: Module) -> None:

-         # If checkpoint is fp16, quantize in place.
+         # If checkpoint is fp16 or bfloat16, quantize in place.
          if not self.quant_config.is_checkpoint_fp8_serialized:
-             w13_weight = torch.empty_like(
-                 layer.w13_weight.data, dtype=torch.float8_e4m3fn
-             )
-             w2_weight = torch.empty_like(
-                 layer.w2_weight.data, dtype=torch.float8_e4m3fn
-             )
+             # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
+             fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
+             w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

              # Re-initialize w13_scale because we directly quantize
              # merged w13 weights and generate a single scaling factor.
@@ -534,6 +535,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                      layer.a2_scale.max(), requires_grad=False
                  )

+             # If ROCm, normalize the weights and scales to e4m3fnuz
+             if is_hip():
+                 # Normalize the weights and scales
+                 w13_weight, w13_scale, a13_scale = normalize_e4m3fn_to_e4m3fnuz(
+                     layer.w13_weight, layer.w13_scale, layer.a13_scale
+                 )
+                 w2_weight, w2_scale, a2_scale = normalize_e4m3fn_to_e4m3fnuz(
+                     layer.w2_weight, layer.w2_scale, layer.a2_scale
+                 )
+                 # Reset the parameters
+                 layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+                 layer.w13_scale = torch.nn.Parameter(w13_scale, requires_grad=False)
+                 if a13_scale is not None:
+                     layer.a13_scale = torch.nn.Parameter(a13_scale, requires_grad=False)
+                 layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+                 layer.w2_scale = torch.nn.Parameter(w2_scale, requires_grad=False)
+                 if a2_scale is not None:
+                     layer.a2_scale = torch.nn.Parameter(a2_scale, requires_grad=False)
+
              # Fp8 moe kernel needs single weight scale for w13 per expert.
              # We take the max then dequant and requant each expert.
              assert layer.w13_scale is not None
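Background for the ROCm branch above: the e4m3fnuz encoding used on MI300-class GPUs has an exponent bias of 8 rather than e4m3fn's 7, so the same bit pattern decodes to half the value; the normalize_e4m3fn_to_e4m3fnuz helper imported here evidently reinterprets the stored FP8 weights and rescales accordingly. A minimal conceptual sketch of that rescaling idea, ours and not the vLLM implementation (a real conversion must also remap the 0x80 bit pattern, which is negative zero in e4m3fn but NaN in e4m3fnuz):

```python
import torch

def to_e4m3fnuz_sketch(weight_fn: torch.Tensor, scale: torch.Tensor):
    # Reinterpret the fp8 bits as e4m3fnuz; each stored value is now half as
    # large, so double the dequantization scale to keep weight * scale unchanged.
    assert weight_fn.dtype == torch.float8_e4m3fn
    weight_fnuz = weight_fn.view(torch.int8).view(torch.float8_e4m3fnuz)
    return weight_fnuz, scale * 2.0
```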
sglang/srt/layers/layernorm.py
@@ -15,6 +15,7 @@ limitations under the License.

  """Fused operators for normalization layers."""

+ import logging
  from typing import Optional, Tuple, Union

  import torch
@@ -27,6 +28,10 @@ from flashinfer.norm import (
  )
  from vllm.model_executor.custom_op import CustomOp

+ from sglang.srt.utils import is_hip
+
+ logger = logging.getLogger(__name__)
+

  class RMSNorm(CustomOp):
      def __init__(
@@ -109,3 +114,10 @@ class GemmaRMSNorm(CustomOp):
              return x, residual
          out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
          return out
+
+
+ if is_hip():
+     logger.info(
+         "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
+     )
+     from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
sglang/srt/layers/logits_processor.py
@@ -37,7 +37,7 @@ class LogitsProcessorOutput:

      # The normlaized logprobs of prompts. shape: [#seq]
      normalized_prompt_logprobs: torch.Tensor
-     # The logprobs of input tokens. shape: [#token, vocab_size]
+     # The logprobs of input tokens. shape: [#token, vocab_size]
      input_token_logprobs: torch.Tensor

      # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
@@ -49,25 +49,39 @@
  @dataclasses.dataclass
  class LogitsMetadata:
      forward_mode: ForwardMode
+     top_logprobs_nums: Optional[List[int]]
+
      return_logprob: bool = False
+     return_top_logprob: bool = False

      extend_seq_lens: Optional[torch.Tensor] = None
-     extend_start_loc: Optional[torch.Tensor] = None
-     top_logprobs_nums: Optional[List[int]] = None
+     extend_seq_lens_cpu: Optional[List[int]] = None

-     extend_seq_lens_cpu: List[int] = None
-     logprob_start_lens_cpu: List[int] = None
+     extend_logprob_start_lens_cpu: Optional[List[int]] = None
+     extend_logprob_pruned_lens_cpu: Optional[List[int]] = None

      @classmethod
      def from_input_metadata(cls, input_metadata: InputMetadata):
+         return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
+         if input_metadata.forward_mode.is_extend():
+             extend_logprob_pruned_lens_cpu = [
+                 extend_len - start_len
+                 for extend_len, start_len in zip(
+                     input_metadata.extend_seq_lens,
+                     input_metadata.extend_logprob_start_lens_cpu,
+                 )
+             ]
+         else:
+             extend_logprob_pruned_lens_cpu = None
          return cls(
              forward_mode=input_metadata.forward_mode,
-             extend_seq_lens=input_metadata.extend_seq_lens,
-             extend_start_loc=input_metadata.extend_start_loc,
-             return_logprob=input_metadata.return_logprob,
              top_logprobs_nums=input_metadata.top_logprobs_nums,
+             return_logprob=input_metadata.return_logprob,
+             return_top_logprob=return_top_logprob,
+             extend_seq_lens=input_metadata.extend_seq_lens,
              extend_seq_lens_cpu=input_metadata.extend_seq_lens_cpu,
-             logprob_start_lens_cpu=input_metadata.logprob_start_lens_cpu,
+             extend_logprob_start_lens_cpu=input_metadata.extend_logprob_start_lens_cpu,
+             extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
          )


@@ -82,57 +96,49 @@ class LogitsProcessor(nn.Module):
      def _get_normalized_prompt_logprobs(
          self,
          input_token_logprobs: torch.Tensor,
-         cum_start_len0: torch.Tensor,
-         cum_start_len1: torch.Tensor,
          logits_metadata: LogitsMetadata,
      ):
          logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
+         pruned_lens = torch.tensor(
+             logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
+         )

-         start = logits_metadata.extend_start_loc.clone() - cum_start_len0
-         end = start + logits_metadata.extend_seq_lens - 2 - cum_start_len1
-         start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
-         end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
+         start = torch.zeros_like(pruned_lens)
+         start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
+         end = torch.clamp(
+             start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
+         )
          sum_logp = (
              logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
          )
-         normalized_prompt_logprobs = sum_logp / (
-             (logits_metadata.extend_seq_lens - 1).clamp(min=1)
-         )
-
+         normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
          return normalized_prompt_logprobs

      @staticmethod
      def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
-         if logits_metadata.forward_mode == ForwardMode.DECODE:
+         max_k = max(logits_metadata.top_logprobs_nums)
+         ret = all_logprobs.topk(max_k, dim=1)
+         values = ret.values.tolist()
+         indices = ret.indices.tolist()
+
+         if logits_metadata.forward_mode.is_decode():
              output_top_logprobs = []
-             max_k = max(logits_metadata.top_logprobs_nums)
-             ret = all_logprobs.topk(max_k, dim=1)
-             values = ret.values.tolist()
-             indices = ret.indices.tolist()
              for i, k in enumerate(logits_metadata.top_logprobs_nums):
                  output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k])))
              return None, output_top_logprobs
          else:
-             # TODO: vectorize the code below
              input_top_logprobs, output_top_logprobs = [], []
-             pt = 0
-             extend_seq_lens_cpu = logits_metadata.extend_seq_lens_cpu

-             max_k = max(logits_metadata.top_logprobs_nums)
-             ret = all_logprobs.topk(max_k, dim=1)
-             values = ret.values.tolist()
-             indices = ret.indices.tolist()
-
-             for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
-                 start_len = logits_metadata.logprob_start_lens_cpu[i]
-                 pruned_len = extend_seq_len - start_len
-
-                 if extend_seq_len == 0:
+             pt = 0
+             for k, pruned_len in zip(
+                 logits_metadata.top_logprobs_nums,
+                 logits_metadata.extend_logprob_pruned_lens_cpu,
+             ):
+                 if pruned_len <= 0:
                      input_top_logprobs.append([])
                      output_top_logprobs.append([])
                      continue

-                 k = logits_metadata.top_logprobs_nums[i]
                  input_top_logprobs.append(
                      [
                          list(zip(values[pt + j][:k], indices[pt + j][:k]))
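The rewritten _get_normalized_prompt_logprobs above works purely from the pruned per-request lengths (extend length minus the logprob start offset): an exclusive cumsum gives each request's starting row, and the prefix-sum trick turns per-request logprob sums into two lookups. A tiny self-contained check with made-up numbers (CPU tensors, illustrative only):

```python
import torch

# Two packed sequences of pruned lengths 3 and 4 (7 token positions total).
input_token_logprobs = torch.tensor([-1.0, -2.0, -3.0, -0.5, -0.5, -4.0, -1.0])
pruned_lens = torch.tensor([3, 4])

cumsum = torch.cumsum(input_token_logprobs, dim=0)
start = torch.zeros_like(pruned_lens)
start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
end = torch.clamp(start + pruned_lens - 2, min=0, max=cumsum.shape[0] - 1)
sum_logp = cumsum[end] - cumsum[start] + input_token_logprobs[start]
normalized = sum_logp / (pruned_lens - 1).clamp(min=1)

# Per sequence this averages the first (len - 1) token logprobs:
# seq 0: mean(-1.0, -2.0) = -1.5 ; seq 1: mean(-0.5, -0.5, -4.0) = -5/3
print(normalized)  # tensor([-1.5000, -1.6667])
```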
@@ -163,14 +169,11 @@ class LogitsProcessor(nn.Module):
          assert isinstance(logits_metadata, LogitsMetadata)

          # Get the last hidden states and last logits for the next token prediction
-         if logits_metadata.forward_mode == ForwardMode.DECODE:
+         if logits_metadata.forward_mode.is_decode():
              last_index = None
              last_hidden = hidden_states
          else:
-             last_index = (
-                 torch.cumsum(logits_metadata.extend_seq_lens, dim=0, dtype=torch.long)
-                 - 1
-             )
+             last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
              last_hidden = hidden_states[last_index]

          last_logits = torch.matmul(last_hidden, weight.T)
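The simplified last_index computation above relies on the batch being packed back-to-back, so the cumulative sum of per-request extend lengths, minus one, lands on each request's final position. A quick illustration with hypothetical lengths:

```python
import torch

extend_seq_lens = torch.tensor([4, 2, 3])           # packed requests of 4, 2 and 3 tokens
last_index = torch.cumsum(extend_seq_lens, dim=0) - 1
print(last_index)                                   # tensor([3, 5, 8])
```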
@@ -194,21 +197,15 @@ class LogitsProcessor(nn.Module):
                  output_top_logprobs=None,
              )
          else:
-             # When logprob is requested, compute the logits for all tokens.
-             if logits_metadata.forward_mode == ForwardMode.DECODE:
-                 last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
+             last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)

-                 # Get the logprob of top-k tokens
-                 return_top_logprob = any(
-                     x > 0 for x in logits_metadata.top_logprobs_nums
-                 )
-                 if return_top_logprob:
+             if logits_metadata.forward_mode.is_decode():
+                 if logits_metadata.return_top_logprob:
                      output_top_logprobs = self.get_top_logprobs(
                          last_logprobs, logits_metadata
                      )[1]
                  else:
                      output_top_logprobs = None
-
                  return LogitsProcessorOutput(
                      next_token_logits=last_logits,
                      next_token_logprobs=last_logprobs,
@@ -218,22 +215,18 @@ class LogitsProcessor(nn.Module):
                      output_top_logprobs=output_top_logprobs,
                  )
              else:
+                 # Slice the requested tokens to compute logprob
                  pt, states, pruned_input_ids = 0, [], []
-                 for i, extend_len in enumerate(logits_metadata.extend_seq_lens_cpu):
-                     start_len = logits_metadata.logprob_start_lens_cpu[i]
+                 for start_len, extend_len in zip(
+                     logits_metadata.extend_logprob_start_lens_cpu,
+                     logits_metadata.extend_seq_lens_cpu,
+                 ):
                      states.append(hidden_states[pt + start_len : pt + extend_len])
                      pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
                      pt += extend_len

+                 # Compute the logits and logprobs for all required tokens
                  states = torch.cat(states, dim=0)
-                 pruned_input_ids = torch.cat(pruned_input_ids, dim=0)
-
-                 cum_start_len1 = torch.tensor(
-                     logits_metadata.logprob_start_lens_cpu, device="cuda"
-                 ).cumsum(0)
-                 cum_start_len0 = torch.zeros_like(cum_start_len1)
-                 cum_start_len0[1:] = cum_start_len1[:-1]
-
                  all_logits = torch.matmul(states, weight.T)
                  if self.do_tensor_parallel_all_gather:
                      all_logits = tensor_model_parallel_all_gather(all_logits)
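The new slicing loop above walks the packed batch with a running offset pt and keeps, for each request, only the tail starting at its logprob start offset. A small illustration with invented lengths:

```python
import torch

# Two requests with extend lengths [4, 3] and logprob start lengths [1, 0].
extend_seq_lens_cpu = [4, 3]
extend_logprob_start_lens_cpu = [1, 0]
input_ids = torch.arange(7)  # packed token ids of both requests

pt, pruned = 0, []
for start_len, extend_len in zip(extend_logprob_start_lens_cpu, extend_seq_lens_cpu):
    pruned.append(input_ids[pt + start_len : pt + extend_len])
    pt += extend_len
# pruned -> [tensor([1, 2, 3]), tensor([4, 5, 6])]
```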
@@ -249,35 +242,29 @@ class LogitsProcessor(nn.Module):
                  all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)

                  # Get the logprob of top-k tokens
-                 return_top_logprob = any(
-                     x > 0 for x in logits_metadata.top_logprobs_nums
-                 )
-                 if return_top_logprob:
+                 if logits_metadata.return_top_logprob:
                      input_top_logprobs, output_top_logprobs = self.get_top_logprobs(
                          all_logprobs, logits_metadata
                      )
                  else:
                      input_top_logprobs = output_top_logprobs = None

-                 last_logprobs = all_logprobs[last_index - cum_start_len1]
-
-                 # Compute the logprobs and normalized logprobs for the prefill tokens.
-                 # Note that we pad a zero at the end of each sequence for easy computation.
+                 # Compute the normalized logprobs for the requested tokens.
+                 # Note that we pad a zero at the end for easy batching.
                  input_token_logprobs = all_logprobs[
                      torch.arange(all_logprobs.shape[0], device="cuda"),
-                     torch.cat([pruned_input_ids[1:], torch.tensor([0], device="cuda")]),
+                     torch.cat(
+                         [
+                             torch.cat(pruned_input_ids)[1:],
+                             torch.tensor([0], device="cuda"),
+                         ]
+                     ),
                  ]
-
                  normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
                      input_token_logprobs,
-                     cum_start_len0,
-                     cum_start_len1,
                      logits_metadata,
                  )

-                 # Remove the last token logprob for the prefill tokens.
-                 input_token_logprobs = input_token_logprobs[:-1]
-
                  return LogitsProcessorOutput(
                      next_token_logits=last_logits,
                      next_token_logprobs=last_logprobs,
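Finally, the padded gather above pairs row i of all_logprobs (the distribution predicted at position i) with the token at position i + 1, padding a zero index for the last row, whose next token does not exist. A compact standalone illustration of the same indexing pattern (invented sizes and tokens):

```python
import torch

torch.manual_seed(0)
vocab, tokens = 5, torch.tensor([3, 1, 4, 2])
all_logprobs = torch.log_softmax(torch.randn(len(tokens), vocab), dim=-1)

# Shift tokens left by one and pad a dummy index for the final position.
next_tokens = torch.cat([tokens[1:], torch.tensor([0])])
input_token_logprobs = all_logprobs[torch.arange(len(tokens)), next_tokens]
```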