sglang 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -253,14 +253,14 @@ class RuntimeEndpoint(BaseBackend):
             r["meta_info"]["normalized_prompt_logprob"] for r in obj
         ]
         decision = choices[np.argmax(normalized_prompt_logprobs)]
-        prefill_token_logprobs = [r["meta_info"]["prefill_token_logprobs"] for r in obj]
-        decode_token_logprobs = [r["meta_info"]["decode_token_logprobs"] for r in obj]
+        input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
+        output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]

         return (
             decision,
             normalized_prompt_logprobs,
-            prefill_token_logprobs,
-            decode_token_logprobs,
+            input_token_logprobs,
+            output_token_logprobs,
         )

     def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
@@ -541,16 +541,16 @@ class StreamExecutor:
         (
             decision,
             normalized_prompt_logprobs,
-            prefill_token_logprobs,
-            decode_token_logprobs,
+            input_token_logprobs,
+            output_token_logprobs,
         ) = self.backend.select(self, expr.choices, expr.temperature)
         if expr.name is not None:
             name = expr.name
             self.variables[name] = decision
             self.meta_info[name] = {
                 "normalized_prompt_logprobs": normalized_prompt_logprobs,
-                "prefill_token_logprobs": prefill_token_logprobs,
-                "decode_token_logprobs": decode_token_logprobs,
+                "input_token_logprobs": input_token_logprobs,
+                "output_token_logprobs": output_token_logprobs,
             }
             self.variable_event[name].set()
         self.text_ += decision
@@ -21,7 +21,27 @@ class FSMCache(BaseCache):
             tokenizer = AutoTokenizer.from_pretrained(
                 tokenizer_path, **tokenizer_args_dict
             )
-            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
+            try:
+                self.outlines_tokenizer = TransformerTokenizer(tokenizer)
+            except AttributeError:
+                # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
+                origin_pad_token_id = tokenizer.pad_token_id
+
+                def fset(self, value):
+                    self._value = value
+
+                type(tokenizer).pad_token_id = property(
+                    fget=type(tokenizer).pad_token_id.fget, fset=fset
+                )
+                self.outlines_tokenizer = TransformerTokenizer(tokenizer)
+                self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
+                self.outlines_tokenizer.pad_token_id = origin_pad_token_id
+                self.outlines_tokenizer.pad_token = (
+                    self.outlines_tokenizer.tokenizer.pad_token
+                )
+                self.outlines_tokenizer.vocabulary = (
+                    self.outlines_tokenizer.tokenizer.get_vocab()
+                )
         else:
             self.outlines_tokenizer = TransformerTokenizer(
                 tokenizer_path, **tokenizer_args_dict
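
The try/except added above appears to work around tokenizers whose pad_token_id is a getter-only property (the FIXME names chatglm2/chatglm3, where pad_token_id is 0): outlines' TransformerTokenizer assigns to that attribute, hits an AttributeError, and the patch re-installs the property with a setter before restoring the original value. A standalone sketch of the same property-override trick, using a hypothetical class and no sglang or outlines dependency:

class ReadOnlyPadTokenizer:
    """Hypothetical stand-in for a tokenizer whose pad_token_id has no setter."""

    @property
    def pad_token_id(self):
        return getattr(self, "_value", 0)


def fset(self, value):
    self._value = value


# Re-install the property, keeping the original getter and adding a setter,
# the same move the hunk above applies to type(tokenizer).
ReadOnlyPadTokenizer.pad_token_id = property(
    fget=ReadOnlyPadTokenizer.pad_token_id.fget, fset=fset
)

tok = ReadOnlyPadTokenizer()
tok.pad_token_id = 2     # would raise AttributeError without the patched setter
print(tok.pad_token_id)  # 2
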
@@ -73,7 +73,9 @@ def get_context_length(config):
     rope_scaling = getattr(config, "rope_scaling", None)
     if rope_scaling:
         rope_scaling_factor = config.rope_scaling["factor"]
-        if config.rope_scaling["rope_type"] == "llama3":
+        if "original_max_position_embeddings" in rope_scaling:
+            rope_scaling_factor = 1
+        if config.rope_scaling.get("rope_type", None) == "llama3":
             rope_scaling_factor = 1
     else:
         rope_scaling_factor = 1
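
The change above keeps the RoPE scaling factor at 1 when the rope_scaling block carries original_max_position_embeddings (as Llama-3.1-style configs do) as well as when rope_type is "llama3", and it reads rope_type with .get() so configs that lack the key no longer raise KeyError. A standalone sketch of just this factor-selection logic (not the sglang function itself), under the assumption that get_context_length otherwise multiplies the factor into the reported context length:

def pick_rope_scaling_factor(rope_scaling):
    # Mirrors the selection logic in the hunk above, on a plain dict.
    if not rope_scaling:
        return 1
    rope_scaling_factor = rope_scaling["factor"]
    if "original_max_position_embeddings" in rope_scaling:
        rope_scaling_factor = 1
    if rope_scaling.get("rope_type", None) == "llama3":
        rope_scaling_factor = 1
    return rope_scaling_factor


llama31_style = {
    "factor": 8.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3",
}
linear_style = {"factor": 4.0, "type": "linear"}

print(pick_rope_scaling_factor(llama31_style))  # 1
print(pick_rope_scaling_factor(linear_style))   # 4.0
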
@@ -1,7 +1,7 @@
 """Logits processing."""

 import dataclasses
-from typing import List, Union
+from typing import List, Optional, Union

 import torch
 from torch import nn
@@ -22,23 +22,23 @@ class LogitProcessorOutput:

     # The normlaized logprobs of prompts. shape: [#seq]
     normalized_prompt_logprobs: torch.Tensor
-    # The logprobs of prefill tokens. shape: [#token, vocab_size]
-    prefill_token_logprobs: torch.Tensor
+    # The logprobs of input tokens. shape: [#token, vocab_size]
+    input_token_logprobs: torch.Tensor

-    # The logprob and id of the top-k tokens in prefill positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-    prefill_top_logprobs: List
-    # The logprob and id of the top-k tokens in decode positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-    decode_top_logprobs: List
+    # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
+    input_top_logprobs: List
+    # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
+    output_top_logprobs: List


 @dataclasses.dataclass
 class LogitsMetadata:
     forward_mode: ForwardMode
-    return_logprob: bool
+    return_logprob: bool = False

-    extend_seq_lens: torch.Tensor = None
-    extend_start_loc: torch.Tensor = None
-    top_logprobs_nums: List[int] = None
+    extend_seq_lens: Optional[torch.Tensor] = None
+    extend_start_loc: Optional[torch.Tensor] = None
+    top_logprobs_nums: Optional[List[int]] = None

     @classmethod
     def from_input_metadata(cls, input_metadata: InputMetadata):
@@ -58,20 +58,16 @@ class LogitsProcessor(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()

     def _get_normalized_prompt_logprobs(
-        self, prefill_token_logprobs, logits_metadata: LogitsMetadata
+        self, input_token_logprobs, logits_metadata: LogitsMetadata
     ):
-        logprobs_cumsum = torch.cumsum(
-            prefill_token_logprobs, dim=0, dtype=torch.float32
-        )
+        logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)

         start = logits_metadata.extend_start_loc.clone()
         end = start + logits_metadata.extend_seq_lens - 2
-        start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
-        end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
+        start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
+        end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
         sum_logp = (
-            logprobs_cumsum[end]
-            - logprobs_cumsum[start]
-            + prefill_token_logprobs[start]
+            logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
         )
         normalized_prompt_logprobs = sum_logp / (
             (logits_metadata.extend_seq_lens - 1).clamp(min=1)
@@ -79,37 +75,38 @@ class LogitsProcessor(nn.Module):

         return normalized_prompt_logprobs

-    def _get_top_logprobs(self, all_logprobs, logits_metadata: LogitsMetadata):
+    @staticmethod
+    def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
         # TODO: vectorize the code below
         if logits_metadata.forward_mode == ForwardMode.DECODE:
-            decode_top_logprobs = []
+            output_top_logprobs = []
             for i in range(all_logprobs.shape[0]):
                 k = logits_metadata.top_logprobs_nums[i]
                 t = all_logprobs[i].topk(k)
                 v_cpu = t.values.tolist()
                 p_cpu = t.indices.tolist()
-                decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
-            return None, decode_top_logprobs
+                output_top_logprobs.append(list(zip(v_cpu, p_cpu)))
+            return None, output_top_logprobs
         else:
-            prefill_top_logprobs, decode_top_logprobs = [], []
+            input_top_logprobs, output_top_logprobs = [], []
             pt = 0
             extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
             for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
                 if extend_seq_len == 0:
-                    prefill_top_logprobs.append([])
-                    decode_top_logprobs.append([])
+                    input_top_logprobs.append([])
+                    output_top_logprobs.append([])
                     continue
                 k = logits_metadata.top_logprobs_nums[i]
                 t = all_logprobs[pt : pt + extend_seq_len].topk(k)
                 vs_cpu = t.values.tolist()
                 ps_cpu = t.indices.tolist()
-                prefill_top_logprobs.append(
+                input_top_logprobs.append(
                     [list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
                 )
-                decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
+                output_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
                 pt += extend_seq_len

-        return prefill_top_logprobs, decode_top_logprobs
+        return input_top_logprobs, output_top_logprobs

     def forward(
         self,
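
With get_top_logprobs promoted to a public @staticmethod above, it can now be called outside of forward() (as CudaGraphRunner.replay does further down in this diff). A minimal sketch of calling it directly for a decode step, assuming a full sglang 0.2.6 install since these modules pull in vllm and flashinfer:

import torch

from sglang.srt.layers.logits_processor import LogitsMetadata, LogitsProcessor
from sglang.srt.managers.controller.infer_batch import ForwardMode

# Two decode requests over a toy vocabulary of 8 tokens.
logprobs = torch.log_softmax(torch.randn(2, 8), dim=-1)
metadata = LogitsMetadata(
    forward_mode=ForwardMode.DECODE,
    top_logprobs_nums=[3, 2],  # per-request top-k
)

# In DECODE mode the first element of the returned pair is always None.
input_top, output_top = LogitsProcessor.get_top_logprobs(logprobs, metadata)
print(input_top)                     # None
print([len(x) for x in output_top])  # [3, 2]; entries are (logprob, token_id) pairs
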
@@ -136,7 +133,7 @@ class LogitsProcessor(nn.Module):
         last_logits = torch.matmul(last_hidden, weight.T)
         if self.tp_size > 1:
             last_logits = tensor_model_parallel_all_gather(last_logits)
-        last_logits = last_logits[:, : self.config.vocab_size]
+        last_logits = last_logits[:, : self.config.vocab_size].float()

         if hasattr(self.config, "final_logit_softcapping"):
             last_logits /= self.config.final_logit_softcapping
@@ -149,63 +146,75 @@ class LogitsProcessor(nn.Module):
                 next_token_logits=last_logits,
                 next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
-                prefill_token_logprobs=None,
-                prefill_top_logprobs=None,
-                decode_top_logprobs=None,
+                input_token_logprobs=None,
+                input_top_logprobs=None,
+                output_top_logprobs=None,
             )
         else:
             # When logprob is requested, compute the logits for all tokens.
             if logits_metadata.forward_mode == ForwardMode.DECODE:
-                all_logits = last_logits
-            else:
-                all_logits = torch.matmul(hidden_states, weight.T)
-                if self.tp_size > 1:
-                    all_logits = tensor_model_parallel_all_gather(all_logits)
-                all_logits = all_logits[:, : self.config.vocab_size]
-
-            all_logprobs = all_logits.float()
-            del all_logits
-            all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
+                last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)

-            # Get the logprob of top-k tokens
-            return_top_logprob = any(x > 0 for x in logits_metadata.top_logprobs_nums)
-            if return_top_logprob:
-                prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
-                    all_logprobs, logits_metadata
+                # Get the logprob of top-k tokens
+                return_top_logprob = any(
+                    x > 0 for x in logits_metadata.top_logprobs_nums
                 )
-            else:
-                prefill_top_logprobs = decode_top_logprobs = None
+                if return_top_logprob:
+                    output_top_logprobs = self.get_top_logprobs(
+                        last_logprobs, logits_metadata
+                    )[1]
+                else:
+                    output_top_logprobs = None

-            if logits_metadata.forward_mode == ForwardMode.DECODE:
                 return LogitProcessorOutput(
                     next_token_logits=last_logits,
-                    next_token_logprobs=all_logprobs,
+                    next_token_logprobs=last_logprobs,
                     normalized_prompt_logprobs=None,
-                    prefill_token_logprobs=None,
-                    prefill_top_logprobs=None,
-                    decode_top_logprobs=decode_top_logprobs,
+                    input_token_logprobs=None,
+                    input_top_logprobs=None,
+                    output_top_logprobs=output_top_logprobs,
                 )
             else:
+                all_logits = torch.matmul(hidden_states, weight.T)
+                if self.tp_size > 1:
+                    all_logits = tensor_model_parallel_all_gather(all_logits)
+                all_logits = all_logits[:, : self.config.vocab_size].float()
+
+                all_logprobs = all_logits
+                del all_logits
+                all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
+
+                # Get the logprob of top-k tokens
+                return_top_logprob = any(
+                    x > 0 for x in logits_metadata.top_logprobs_nums
+                )
+                if return_top_logprob:
+                    input_top_logprobs, output_top_logprobs = self.get_top_logprobs(
+                        all_logprobs, logits_metadata
+                    )
+                else:
+                    input_top_logprobs = output_top_logprobs = None
+
                 last_logprobs = all_logprobs[last_index]

                 # Compute the logprobs and normalized logprobs for the prefill tokens.
                 # Note that we pad a zero at the end of each sequence for easy computation.
-                prefill_token_logprobs = all_logprobs[
+                input_token_logprobs = all_logprobs[
                     torch.arange(all_logprobs.shape[0], device="cuda"),
                     torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
                 ]

                 normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
-                    prefill_token_logprobs, logits_metadata
+                    input_token_logprobs, logits_metadata
                 )

                 return LogitProcessorOutput(
                     next_token_logits=last_logits,
                     next_token_logprobs=last_logprobs,
                     normalized_prompt_logprobs=normalized_prompt_logprobs,
-                    prefill_token_logprobs=prefill_token_logprobs,
-                    prefill_top_logprobs=prefill_top_logprobs,
-                    decode_top_logprobs=decode_top_logprobs,
+                    input_token_logprobs=input_token_logprobs,
+                    input_top_logprobs=input_top_logprobs,
+                    output_top_logprobs=output_top_logprobs,
                 )

@@ -7,8 +7,11 @@ from torch import nn
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
-from sglang.srt.server import global_server_args_dict
+from sglang.srt.managers.controller.model_runner import (
+    ForwardMode,
+    InputMetadata,
+    global_server_args_dict,
+)


 class RadixAttention(nn.Module):
@@ -5,7 +5,7 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.server import global_server_args_dict
+from sglang.srt.managers.controller.infer_batch import global_server_args_dict

 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
@@ -9,7 +9,11 @@ from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp

-from sglang.srt.layers.logits_processor import LogitProcessorOutput
+from sglang.srt.layers.logits_processor import (
+    LogitProcessorOutput,
+    LogitsMetadata,
+    LogitsProcessor,
+)
 from sglang.srt.managers.controller.infer_batch import (
     Batch,
     ForwardMode,
@@ -185,7 +189,6 @@ class CudaGraphRunner:

     def replay(self, batch: Batch):
         assert batch.out_cache_loc is not None
-        assert not batch.return_logprob
         raw_bs = len(batch.reqs)

         # Pad
@@ -218,23 +221,29 @@ class CudaGraphRunner:
         output = self.output_buffers[bs]

         # Unpad
-        if bs == raw_bs:
-            return output
-        else:
+        if bs != raw_bs:
             output = LogitProcessorOutput(
                 next_token_logits=output.next_token_logits[:raw_bs],
-                next_token_logprobs=(
-                    output.next_token_logprobs[:raw_bs]
-                    if output.next_token_logprobs is not None
-                    else None
-                ),
+                next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
-                prefill_token_logprobs=None,
-                prefill_top_logprobs=None,
-                decode_top_logprobs=(
-                    output.decode_top_logprobs[:raw_bs]
-                    if output.decode_top_logprobs is not None
-                    else None
-                ),
+                input_token_logprobs=None,
+                input_top_logprobs=None,
+                output_top_logprobs=None,
             )
+
+        # Extract logprobs
+        if batch.return_logprob:
+            output.next_token_logprobs = torch.nn.functional.log_softmax(
+                output.next_token_logits, dim=-1
+            )
+            return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
+            if return_top_logprob:
+                logits_metadata = LogitsMetadata(
+                    forward_mode=ForwardMode.DECODE,
+                    top_logprobs_nums=batch.top_logprobs_nums,
+                )
+                output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
+                    output.next_token_logprobs, logits_metadata
+                )[1]
+
         return output
@@ -17,6 +17,13 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

+# Put some global args for easy access
+global_server_args_dict = {
+    "disable_flashinfer": False,
+    "disable_flashinfer_sampling": False,
+    "attention_reduce_in_fp32": False,
+}
+

 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
@@ -124,10 +131,10 @@ class Req:
         self.logprob_start_len = 0
         self.top_logprobs_num = 0
         self.normalized_prompt_logprob = None
-        self.prefill_token_logprobs = None
-        self.prefill_top_logprobs = None
-        self.decode_token_logprobs = []
-        self.decode_top_logprobs = []
+        self.input_token_logprobs = None
+        self.input_top_logprobs = None
+        self.output_token_logprobs = []
+        self.output_top_logprobs = []
         # The tokens is prefilled but need to be considered as decode tokens
         # and should be updated for the decode logprobs
         self.last_update_decode_tokens = 0
@@ -244,8 +251,8 @@ class Req:
                     k = k + 1
                 else:
                     break
-            self.decode_token_logprobs = self.decode_token_logprobs[:k]
-            self.decode_top_logprobs = self.decode_top_logprobs[:k]
+            self.output_token_logprobs = self.output_token_logprobs[:k]
+            self.output_top_logprobs = self.output_top_logprobs[:k]
             self.logprob_start_len = prompt_tokens + k
             self.last_update_decode_tokens = len(self.output_ids) - k

@@ -376,7 +383,7 @@ class Batch:
                     logit_bias = torch.zeros(
                         (bs, vocab_size), dtype=torch.float32, device=device
                     )
-                logit_bias[i] = int_token_logit_bias
+                logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias

         # Set fields
         self.input_ids = torch.tensor(
@@ -687,13 +694,21 @@ class Batch:
         # TODO(lmzheng): apply penalty
         probs = torch.softmax(logits, dim=-1)

-        max_top_k_round, batch_size = 32, probs.shape[0]
-        uniform_samples = torch.rand((max_top_k_round, batch_size), device=probs.device)
-        batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
-            probs, uniform_samples, self.top_ks, self.top_ps
-        )
+        if not global_server_args_dict["disable_flashinfer_sampling"]:
+            max_top_k_round, batch_size = 32, probs.shape[0]
+            uniform_samples = torch.rand(
+                (max_top_k_round, batch_size), device=probs.device
+            )
+            batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                probs, uniform_samples, self.top_ks, self.top_ps
+            )
+        else:
+            # Here we provide a slower fallback implementation.
+            batch_next_token_ids, success = top_k_top_p_sampling_from_probs_torch(
+                probs, self.top_ks, self.top_ps
+            )

-        if torch.any(~success):
+        if not torch.all(success):
             warnings.warn("Sampling failed, fallback to top_k=1 strategy")
             probs = probs.masked_fill(torch.isnan(probs), 0.0)
             argmax_ids = torch.argmax(probs, dim=-1)
@@ -933,3 +948,29 @@ def init_triton_args(forward_mode, seq_lens, prefix_lens):
     max_extend_len = int(torch.max(extend_seq_lens))

     return max_seq_len, max_extend_len, start_loc, prefix_lens
+
+
+def top_k_top_p_sampling_from_probs_torch(
+    probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
+):
+    """A top-k and top-k sampling implementation with native pytorch operations."""
+    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+    probs_sort[
+        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
+        >= top_ks.view(-1, 1)
+    ] = 0.0
+    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
+    try:
+        sampled_index = torch.multinomial(probs_sort, num_samples=1)
+    except RuntimeError:
+        batch_next_token_ids = torch.zeros(
+            (probs_sort.shape[0],), dtype=torch.int64, device=probs.device
+        )
+        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
+        return batch_next_token_ids, success
+
+    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
+    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
+    return batch_next_token_ids, success
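
The fallback sampler added above uses only plain PyTorch tensor operations, so it can also be exercised directly. A minimal sketch, assuming a full sglang 0.2.6 install (importing this module pulls in flashinfer):

import torch

from sglang.srt.managers.controller.infer_batch import (
    top_k_top_p_sampling_from_probs_torch,
)

# Four requests over a toy vocabulary of 1000 tokens.
probs = torch.softmax(torch.randn(4, 1000), dim=-1)
top_ks = torch.tensor([1, 20, 50, 1000])     # per-request top-k
top_ps = torch.tensor([1.0, 0.9, 0.8, 1.0])  # per-request top-p

next_token_ids, success = top_k_top_p_sampling_from_probs_torch(probs, top_ks, top_ps)
print(next_token_ids.shape, success.all())   # torch.Size([4]) tensor(True)
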
@@ -25,7 +25,12 @@ from vllm.distributed import (
 from vllm.model_executor.models import ModelRegistry

 from sglang.global_config import global_config
-from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata
+from sglang.srt.managers.controller.infer_batch import (
+    Batch,
+    ForwardMode,
+    InputMetadata,
+    global_server_args_dict,
+)
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -60,7 +65,13 @@ class ModelRunner:
         self.nccl_port = nccl_port
         self.server_args = server_args
         self.is_multimodal_model = is_multimodal_model(self.model_config)
-        monkey_patch_vllm_dummy_weight_loader()
+        global_server_args_dict.update(
+            {
+                "disable_flashinfer": server_args.disable_flashinfer,
+                "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
+                "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+            }
+        )

         # Init torch distributed
         torch.cuda.set_device(self.gpu_id)
@@ -95,7 +106,7 @@ class ModelRunner:

         # Load the model and create memory pool
         self.load_model()
-        self.init_memory_pool(total_gpu_memory)
+        self.init_memory_pool(total_gpu_memory, server_args.max_num_reqs)
         self.init_cublas()
         self.init_flash_infer()

@@ -108,6 +119,7 @@ class ModelRunner:
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )

+        monkey_patch_vllm_dummy_weight_loader()
         device_config = DeviceConfig()
         load_config = LoadConfig(load_format=self.server_args.load_format)
         vllm_model_config = VllmModelConfig(
@@ -176,7 +188,7 @@ class ModelRunner:
         max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token

-    def init_memory_pool(self, total_gpu_memory):
+    def init_memory_pool(self, total_gpu_memory, max_num_reqs=None):
        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)

        if self.max_total_num_tokens <= 0:
@@ -184,11 +196,14 @@ class ModelRunner:
                 "Not enough memory. Please try to increase --mem-fraction-static."
             )

-        self.req_to_token_pool = ReqToTokenPool(
-            max(
+        if max_num_reqs is None:
+            max_num_reqs = max(
                 int(self.max_total_num_tokens / self.model_config.context_len * 512),
                 2048,
-            ),
+            )
+
+        self.req_to_token_pool = ReqToTokenPool(
+            max_num_reqs,
             self.model_config.context_len + 8,
         )
         self.token_to_kv_pool = TokenToKVPool(