sglang 0.1.19__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sglang/__init__.py CHANGED
@@ -1,4 +1,4 @@
-__version__ = "0.1.19"
+__version__ = "0.1.21"
 
 # SGL API Components
 from sglang.api import (
@@ -12,7 +12,6 @@ from sglang.utils import http_request
 
 
 class RuntimeEndpoint(BaseBackend):
-
     def __init__(
         self,
        base_url: str,
@@ -38,7 +37,8 @@ class RuntimeEndpoint(BaseBackend):
         self.model_info = res.json()
 
         self.chat_template = get_chat_template_by_model_path(
-            self.model_info["model_path"])
+            self.model_info["model_path"]
+        )
 
     def get_model_name(self):
         return self.model_info["model_path"]
@@ -124,7 +124,12 @@ class RuntimeEndpoint(BaseBackend):
         else:
             raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
 
-        for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]:
+        for item in [
+            "return_logprob",
+            "logprob_start_len",
+            "top_logprobs_num",
+            "return_text_in_logprobs",
+        ]:
             value = getattr(sampling_params, item, None)
             if value is not None:
                 data[item] = value
@@ -171,7 +176,12 @@ class RuntimeEndpoint(BaseBackend):
         else:
             raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
 
-        for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]:
+        for item in [
+            "return_logprob",
+            "logprob_start_len",
+            "top_logprobs_num",
+            "return_text_in_logprobs",
+        ]:
             value = getattr(sampling_params, item, None)
             if value is not None:
                 data[item] = value
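The reformatted loops above forward the optional logprob-related sampling parameters to the server only when they are explicitly set. A minimal standalone sketch of that pattern, using an illustrative stand-in for the sampling-params object and request payload:

```python
# Sketch: copy optional attributes into the request payload only when set.
# `SamplingParams` and `data` here are illustrative stand-ins, not SGLang API.
from dataclasses import dataclass
from typing import Optional

@dataclass
class SamplingParams:
    return_logprob: Optional[bool] = None
    logprob_start_len: Optional[int] = None
    top_logprobs_num: Optional[int] = None
    return_text_in_logprobs: Optional[bool] = None

sampling_params = SamplingParams(return_logprob=True, top_logprobs_num=5)
data = {"text": "Hello"}

for item in [
    "return_logprob",
    "logprob_start_len",
    "top_logprobs_num",
    "return_text_in_logprobs",
]:
    value = getattr(sampling_params, item, None)
    if value is not None:
        data[item] = value

print(data)  # {'text': 'Hello', 'return_logprob': True, 'top_logprobs_num': 5}
```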
sglang/bench_latency.py CHANGED
@@ -70,6 +70,7 @@ class BenchArgs:
 
 def load_model(server_args, tp_rank):
     suppress_other_loggers()
+    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
 
     model_config = ModelConfig(path=server_args.model_path)
     model_runner = ModelRunner(
@@ -81,7 +82,7 @@ def load_model(server_args, tp_rank):
         nccl_port=28888,
         server_args=server_args,
     )
-    print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
+    rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
     tokenizer = get_tokenizer(
         server_args.tokenizer_path,
         tokenizer_mode=server_args.tokenizer_mode,
@@ -201,7 +202,7 @@ def correctness_test(
 
     # Print
     for i in range(len(reqs)):
-        print(tokenizer.decode(output_ids[i]))
+        rank_print(tokenizer.decode(output_ids[i]))
 
 
 def latency_test(
@@ -213,7 +214,7 @@ def latency_test(
 
     # Load the model
     model_runner, tokenizer = load_model(server_args, tp_rank)
-    print(
+    rank_print(
         f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
     )
 
@@ -299,6 +300,8 @@ def main(server_args, bench_args):
     for proc in workers:
         proc.join()
 
+    proc.terminate()
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
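The `rank_print` helper introduced above silences output on every tensor-parallel rank except rank 0, so a multi-process benchmark run prints each message once. A minimal standalone sketch of the same idiom (the wrapper function is illustrative):

```python
# Print only on tensor-parallel rank 0; other ranks get a no-op callable.
def make_rank_print(tp_rank: int):
    return print if tp_rank == 0 else lambda *args, **kwargs: None

rank_print = make_rank_print(tp_rank=0)
rank_print("max_total_num_tokens=...")    # printed once, on rank 0
make_rank_print(tp_rank=1)("suppressed")  # no output on other ranks
```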
sglang/global_config.py CHANGED
@@ -8,36 +8,42 @@ class GlobalConfig:
         # 2: output final text after every run
         self.verbosity = 0
 
+        # Default backend of the language
         self.default_backend = None
 
-        # Output configs
+        # Runtime constants: Request dependency time due to network delay
+        self.request_dependency_delay = 0.02
+        self.wait_for_new_request_delay = 0.0006
+
+        # Runtime constants: New generation token ratio estimation
+        self.base_new_token_ratio = 0.4
+        self.base_min_new_token_ratio = 0.2
+        self.new_token_ratio_decay = 0.0001
+        self.new_token_ratio_recovery = 0.05
+
+        # Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
+        # This can improve the speed for large batch sizes during prefill.
+        self.layer_sync_threshold = 8192
+
+        # Runtime constants: others
+        self.num_continue_decode_steps = 10
+        self.flashinfer_workspace_size = 192 * 1024 * 1024
+
+        # Output tokenization configs
         self.skip_special_tokens_in_output = True
         self.spaces_between_special_tokens_in_out = True
 
-        # Optimization configs
+        # Interpreter optimization configs
         self.eager_fill_image = False
         self.enable_precache_with_tracing = True
         self.enable_parallel_encoding = True
         self.enable_parallel_decoding = True
 
+        # Deprecated
         # Choices: ["no_adjust", "adjust_cache"]
         # no_adjust: Do not adjust the position embedding of KV cache.
         # adjust_cache: Adjust the position embedding of KV cache.
         self.concate_and_append_mode = "no_adjust"
 
-        # Request dependency time due to network delay
-        self.request_dependency_delay = 0.02
-        self.wait_for_new_request_delay = 0.0006
-
-        # New generation token ratio estimation
-        self.base_new_token_ratio = 0.4
-        self.base_min_new_token_ratio = 0.2
-        self.new_token_ratio_decay = 0.0001
-        self.new_token_ratio_recovery = 0.05
-
-        # The threshold (number of tokens) to trigger layer-wise cuda sync.
-        # This can improve the speed for large batch sizes during prefill.
-        self.layer_sync_threshold = 8192
-
 
 global_config = GlobalConfig()
@@ -84,7 +84,7 @@ register_chat_template(
             "system": ("SYSTEM:", "\n"),
             "user": ("USER:", "\n"),
             "assistant": ("ASSISTANT:", "\n"),
-        }
+        },
     )
 )
 
@@ -177,7 +177,7 @@ register_chat_template(
             "assistant": ("", "<|im_end|>\n"),
         },
         style=ChatTemplateStyle.PLAIN,
-        stop_str=("<|im_end|>",)
+        stop_str=("<|im_end|>",),
    )
 )
 
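The scheduler- and runtime-related constants are now grouped on `GlobalConfig` under "Runtime constants". A short sketch of how such values would typically be read elsewhere in the codebase (attribute names are taken from the diff; the surrounding usage is illustrative):

```python
from sglang.global_config import global_config

# Attribute names and values come from GlobalConfig.__init__ in 0.1.21.
delay = global_config.request_dependency_delay          # 0.02 s
decode_steps = global_config.num_continue_decode_steps  # 10
workspace_bytes = global_config.flashinfer_workspace_size  # 192 * 1024 * 1024

print(delay, decode_steps, workspace_bytes)
```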
sglang/lang/ir.py CHANGED
@@ -24,9 +24,9 @@ class SglSamplingParams:
     presence_penalty: float = 0.0
     ignore_eos: bool = False
     return_logprob: Optional[bool] = None
-    logprob_start_len: Optional[int] = None,
-    top_logprobs_num: Optional[int] = None,
-    return_text_in_logprobs: Optional[bool] = None,
+    logprob_start_len: Optional[int] = (None,)
+    top_logprobs_num: Optional[int] = (None,)
+    return_text_in_logprobs: Optional[bool] = (None,)
 
     # for constrained generation, not included in to_xxx_kwargs
     dtype: Optional[str] = None
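Note what the new defaults actually are: the trailing comma after `= None` in the 0.1.19 class body already made each default a one-element tuple, and the 0.1.21 formatting only makes that explicit as `(None,)`. A small self-contained sketch of this Python behavior (the class name is illustrative):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Params:
    # The trailing comma turns the right-hand side into a tuple,
    # so the default value is (None,) rather than None.
    logprob_start_len: Optional[int] = None,

print(Params().logprob_start_len)          # (None,)
print(Params().logprob_start_len is None)  # False
```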
@@ -1,6 +1,5 @@
 """Radix attention."""
 
-import numpy as np
 import torch
 from flashinfer.cascade import merge_state
 from torch import nn
@@ -8,6 +7,7 @@ from torch import nn
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
+from sglang.srt.managers.controller.infer_batch import global_server_args_dict
 from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
 
 
@@ -29,24 +29,14 @@ class RadixAttention(nn.Module):
         self.scaling = scaling
         self.layer_id = layer_id
 
-        from sglang.srt.managers.controller.model_runner import global_server_args_dict
-
         if not global_server_args_dict.get("disable_flashinfer", False):
-            self.prefill_forward = self.prefill_forward_flashinfer
-            self.extend_forward = self.prefill_forward_flashinfer
+            self.extend_forward = self.extend_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
-            # flashinfer now accepts float logit_cap argument
-            self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
         else:
-            self.prefill_forward = self.prefill_forward_triton
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton
-            self.logit_cap = logit_cap if logit_cap is not None else 0
 
-    def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        # In SGLang, we call both the typical "prefill" and "prefill with cache" as "extend".
-        # See the extend_forward_xxx functions.
-        raise NotImplementedError()
+        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
 
     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         o = torch.empty_like(q)
@@ -60,13 +50,13 @@ class RadixAttention(nn.Module):
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
             input_metadata.req_to_token_pool.req_to_token,
             input_metadata.req_pool_indices,
-            input_metadata.start_loc,
+            input_metadata.triton_start_loc,
             input_metadata.seq_lens,
-            input_metadata.prefix_lens,
+            input_metadata.triton_prefix_lens,
             input_metadata.extend_start_loc,
             input_metadata.extend_seq_lens,
-            input_metadata.max_seq_len,
-            input_metadata.max_extend_len,
+            input_metadata.triton_max_seq_len,
+            input_metadata.triton_max_extend_len,
             sm_scale=self.scaling,
             logit_cap=self.logit_cap,
         )
@@ -84,10 +74,9 @@ class RadixAttention(nn.Module):
             o.view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.req_to_token_pool.req_to_token,
             input_metadata.req_pool_indices,
-            input_metadata.start_loc,
+            input_metadata.triton_start_loc,
             input_metadata.seq_lens,
-            input_metadata.max_seq_len,
-            input_metadata.other_kv_index,
+            input_metadata.triton_max_seq_len,
             input_metadata.total_num_tokens,
             sm_scale=self.scaling,
             logit_cap=self.logit_cap,
@@ -95,7 +84,7 @@ class RadixAttention(nn.Module):
 
         return o
 
-    def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
+    def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
@@ -105,7 +94,7 @@ class RadixAttention(nn.Module):
             logits_soft_cap=self.logit_cap,
         )
 
-        if input_metadata.no_prefix:
+        if input_metadata.extend_no_prefix:
             o = o1
         else:
             o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
@@ -141,25 +130,13 @@ class RadixAttention(nn.Module):
         k = k.view(-1, self.tp_k_head_num, self.head_dim)
         v = v.view(-1, self.tp_v_head_num, self.head_dim)
 
-        if input_metadata.forward_mode == ForwardMode.PREFILL:
-            return self.prefill_forward(q, k, v, input_metadata)
-        elif input_metadata.forward_mode == ForwardMode.EXTEND:
+        if input_metadata.forward_mode == ForwardMode.EXTEND:
             return self.extend_forward(q, k, v, input_metadata)
         elif input_metadata.forward_mode == ForwardMode.DECODE:
             return self.decode_forward(q, k, v, input_metadata)
 
     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
         key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
+        key_buffer[input_metadata.out_cache_loc] = cache_k
         value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
-        if input_metadata.out_cache_loc is not None:
-            key_buffer[input_metadata.out_cache_loc] = cache_k
-            value_buffer[input_metadata.out_cache_loc] = cache_v
-        elif input_metadata.out_cache_cont_start is not None:
-            key_buffer[
-                input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
-            ] = cache_k
-            value_buffer[
-                input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
-            ] = cache_v
-        else:
-            raise RuntimeError()
+        value_buffer[input_metadata.out_cache_loc] = cache_v
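The constructor above now picks the attention implementation once, binding `extend_forward`/`decode_forward` to either the FlashInfer or the Triton path instead of keeping a separate prefill path. A minimal sketch of that dispatch pattern with dummy methods (everything except the two method-binding flags is illustrative):

```python
class AttentionDispatch:
    def __init__(self, disable_flashinfer: bool = False):
        # Bind the backend once at construction time instead of
        # branching on every forward call.
        if not disable_flashinfer:
            self.extend_forward = self.extend_forward_flashinfer
            self.decode_forward = self.decode_forward_flashinfer
        else:
            self.extend_forward = self.extend_forward_triton
            self.decode_forward = self.decode_forward_triton

    def extend_forward_flashinfer(self, q): return f"flashinfer extend({q})"
    def decode_forward_flashinfer(self, q): return f"flashinfer decode({q})"
    def extend_forward_triton(self, q): return f"triton extend({q})"
    def decode_forward_triton(self, q): return f"triton decode({q})"

print(AttentionDispatch().extend_forward("q"))                        # flashinfer extend(q)
print(AttentionDispatch(disable_flashinfer=True).decode_forward("q")) # triton decode(q)
```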
@@ -107,7 +107,6 @@ def _fwd_kernel_stage2(
     stride_obs,
     stride_oh,
     stride_req_to_token_b,
-    other_kv_index,  # To fix a NAN issue
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
@@ -138,7 +137,7 @@ def _fwd_kernel_stage2(
         + cur_batch_req_idx * stride_req_to_token_b
         + (start_n + offs_n),
         mask=(start_n + offs_n) < cur_batch_seq_len,
-        other=other_kv_index,
+        other=0,
     )
 
     qk = tl.load(
@@ -250,7 +249,6 @@ def _token_softmax_reducev_fwd(
     b_req_idx,
     b_start_loc,
     b_seq_len,
-    other_kv_index,
 ):
     BLOCK = 64
     batch, head = b_seq_len.shape[0], logics.shape[0]
@@ -277,7 +275,6 @@ def _token_softmax_reducev_fwd(
         o.stride(0),
         o.stride(1),
         req_to_tokens.stride(0),
-        other_kv_index,
     )
     return
 
@@ -295,7 +292,6 @@ def _token_softmax_reducev_fwd(
         o.stride(0),
         o.stride(1),
         req_to_tokens.stride(0),
-        other_kv_index,
         kv_group_num=kv_group_num,
         BLOCK_DMODEL=v_buffer.shape[-1],
         BLOCK_N=BLOCK,
@@ -315,9 +311,8 @@ def token_attention_fwd(
     b_start_loc,
     b_seq_len,
     max_len_in_batch,
-    other_kv_index,
     total_num_tokens,
-    sm_scale=None,
+    sm_scale,
     logit_cap=-1,
     att_m=None,
 ):
@@ -325,7 +320,6 @@ def token_attention_fwd(
     att_m = torch.empty(
         (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
     )
-    sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
 
    _token_att_m_fwd(
        q,
@@ -347,5 +341,4 @@ def token_attention_fwd(
         b_req_idx,
         b_start_loc,
         b_seq_len,
-        other_kv_index,
     )
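`token_attention_fwd` no longer defaults `sm_scale`; callers must pass the softmax scale explicitly (the RadixAttention changes above pass `sm_scale=self.scaling`). The removed default computed the conventional 1/sqrt(head_dim); a tiny sketch of that value at a hypothetical call site:

```python
# Typical softmax scaling now supplied by the caller (values are illustrative).
head_dim = 128
sm_scale = 1.0 / (head_dim ** 0.5)  # the usual 1/sqrt(d_head)
print(sm_scale)
```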
@@ -0,0 +1,196 @@
+"""Run the model with cuda graph."""
+
+import bisect
+
+import torch
+from vllm.distributed.parallel_state import graph_capture
+
+from sglang.global_config import global_config
+from sglang.srt.layers.logits_processor import LogitProcessorOutput
+from sglang.srt.managers.controller.infer_batch import (
+    Batch,
+    ForwardMode,
+    InputMetadata,
+    init_flashinfer_args,
+)
+
+
+class CudaGraphRunner:
+    def __init__(self, model_runner, max_batch_size_to_capture):
+        self.model_runner = model_runner
+        self.graphs = {}
+        self.input_buffers = {}
+        self.output_buffers = {}
+        self.flashinfer_handlers = {}
+        self.graph_memory_pool = None
+
+        # Common inputs
+        self.max_bs = max_batch_size_to_capture
+        self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.req_pool_indices = torch.zeros(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )
+        self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.position_ids_offsets = torch.zeros(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )
+        self.out_cache_loc = torch.zeros(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )
+
+        # FlashInfer inputs
+        self.flashinfer_workspace_buffer = (
+            self.model_runner.flashinfer_workspace_buffers[0]
+        )
+        self.flashinfer_kv_indptr = torch.zeros(
+            (self.max_bs + 1,), dtype=torch.int32, device="cuda"
+        )
+        self.flashinfer_kv_indices = torch.zeros(
+            (self.max_bs * model_runner.model_config.context_len,),
+            dtype=torch.int32,
+            device="cuda",
+        )
+        self.flashinfer_kv_last_page_len = torch.ones(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )
+
+    def can_run(self, batch_size):
+        return batch_size < self.max_bs
+
+    def capture(self, batch_size_list):
+        self.batch_size_list = batch_size_list
+        with graph_capture() as graph_capture_context:
+            self.stream = graph_capture_context.stream
+            for bs in batch_size_list:
+                (
+                    graph,
+                    input_buffers,
+                    output_buffers,
+                    flashinfer_handler,
+                ) = self.capture_one_batch_size(bs)
+                self.graphs[bs] = graph
+                self.input_buffers[bs] = input_buffers
+                self.output_buffers[bs] = output_buffers
+                self.flashinfer_handlers[bs] = flashinfer_handler
+
+    def capture_one_batch_size(self, bs):
+        from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        graph = torch.cuda.CUDAGraph()
+        stream = self.stream
+
+        # Common inputs
+        input_ids = self.input_ids[:bs]
+        req_pool_indices = self.req_pool_indices[:bs]
+        seq_lens = self.seq_lens[:bs]
+        position_ids_offsets = self.position_ids_offsets[:bs]
+        out_cache_loc = self.out_cache_loc[:bs]
+
+        # FlashInfer inputs
+        if not _grouped_size_compiled_for_decode_kernels(
+            self.model_runner.model_config.num_attention_heads
+            // self.model_runner.tp_size,
+            self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
+        ):
+            use_tensor_cores = True
+        else:
+            use_tensor_cores = False
+        flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+            self.flashinfer_workspace_buffer,
+            "NHD",
+            use_cuda_graph=True,
+            use_tensor_cores=use_tensor_cores,
+            paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
+            paged_kv_indices_buffer=self.flashinfer_kv_indices,
+            paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
+        )
+        init_flashinfer_args(
+            ForwardMode.DECODE,
+            self.model_runner,
+            req_pool_indices,
+            seq_lens,
+            None,
+            flashinfer_decode_wrapper,
+        )
+
+        # Run and capture
+        def run_once():
+            input_metadata = InputMetadata.create(
+                self.model_runner,
+                forward_mode=ForwardMode.DECODE,
+                req_pool_indices=req_pool_indices,
+                seq_lens=seq_lens,
+                prefix_lens=None,
+                position_ids_offsets=position_ids_offsets,
+                out_cache_loc=out_cache_loc,
+                return_logprob=False,
+                top_logprobs_nums=0,
+                skip_flashinfer_init=True,
+            )
+            input_metadata.flashinfer_decode_wrapper = flashinfer_decode_wrapper
+            return self.model_runner.model.forward(
+                input_ids, input_metadata.positions, input_metadata
+            )
+
+        for _ in range(2):
+            run_once()
+
+        torch.cuda.synchronize()
+        with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
+            out = run_once()
+        torch.cuda.synchronize()
+        self.graph_memory_pool = graph.pool()
+        return graph, None, out, flashinfer_decode_wrapper
+
+    def replay(self, batch: Batch):
+        assert batch.out_cache_loc is not None
+        assert not batch.return_logprob
+        raw_bs = len(batch.reqs)
+
+        # Pad
+        index = bisect.bisect_left(self.batch_size_list, raw_bs)
+        bs = self.batch_size_list[index]
+        if bs != raw_bs:
+            self.seq_lens.zero_()
+            self.position_ids_offsets.fill_(1)
+            self.out_cache_loc.zero_()
+
+        # Common inputs
+        self.input_ids[:raw_bs] = batch.input_ids
+        self.req_pool_indices[:raw_bs] = batch.req_pool_indices
+        self.seq_lens[:raw_bs] = batch.seq_lens
+        self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
+        self.out_cache_loc[:raw_bs] = batch.out_cache_loc
+
+        # FlashInfer inputs
+        init_flashinfer_args(
+            ForwardMode.DECODE,
+            self.model_runner,
+            self.req_pool_indices[:bs],
+            self.seq_lens[:bs],
+            None,
+            self.flashinfer_handlers[bs],
+        )
+
+        # Replay
+        self.graphs[bs].replay()
+        output = self.output_buffers[bs]
+
+        # Unpad
+        if bs == raw_bs:
+            return output
+        else:
+            output = LogitProcessorOutput(
+                next_token_logits=output.next_token_logits[:raw_bs],
+                next_token_logprobs=output.next_token_logprobs[:raw_bs]
+                if output.next_token_logprobs is not None
+                else None,
+                normalized_prompt_logprobs=None,
+                prefill_token_logprobs=None,
+                prefill_top_logprobs=None,
+                decode_top_logprobs=output.decode_top_logprobs[:raw_bs]
+                if output.decode_top_logprobs is not None
+                else None,
+            )
+            return output
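The new `CudaGraphRunner` follows PyTorch's standard CUDA-graph recipe: allocate static input/output buffers, warm the kernels up, capture one decode step per batch size with `torch.cuda.graph`, then replay after copying fresh inputs into the static buffers. A minimal self-contained sketch of that recipe (requires a CUDA device; the tiny linear model stands in for the decode step and is purely illustrative):

```python
import torch

assert torch.cuda.is_available()
model = torch.nn.Linear(16, 16).cuda()

# Static buffers that the captured graph will read from and write to.
static_in = torch.zeros(8, 16, device="cuda")
static_out = None

# Warm-up runs on a side stream so capture sees initialized kernels.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(2):
        model(static_in)
torch.cuda.current_stream().wait_stream(s)

# Capture a single forward pass into a CUDA graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = model(static_in)

# Replay: copy new data into the static input buffer, then replay the graph;
# static_out is updated in place by the replay.
static_in.copy_(torch.randn(8, 16, device="cuda"))
graph.replay()
print(static_out.sum().item())
```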