sglang 0.1.18__py3-none-any.whl → 0.1.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. sglang/__init__.py +1 -1
  2. sglang/api.py +26 -0
  3. sglang/backend/runtime_endpoint.py +18 -14
  4. sglang/bench_latency.py +40 -18
  5. sglang/global_config.py +21 -16
  6. sglang/lang/chat_template.py +41 -6
  7. sglang/lang/interpreter.py +5 -1
  8. sglang/lang/ir.py +61 -25
  9. sglang/srt/constrained/__init__.py +3 -2
  10. sglang/srt/hf_transformers_utils.py +7 -3
  11. sglang/srt/layers/extend_attention.py +2 -1
  12. sglang/srt/layers/fused_moe.py +181 -167
  13. sglang/srt/layers/logits_processor.py +55 -19
  14. sglang/srt/layers/radix_attention.py +33 -59
  15. sglang/srt/layers/token_attention.py +4 -8
  16. sglang/srt/managers/controller/cuda_graph_runner.py +172 -0
  17. sglang/srt/managers/controller/infer_batch.py +244 -36
  18. sglang/srt/managers/controller/manager_single.py +1 -1
  19. sglang/srt/managers/controller/model_runner.py +69 -284
  20. sglang/srt/managers/controller/tp_worker.py +39 -20
  21. sglang/srt/managers/detokenizer_manager.py +4 -2
  22. sglang/srt/managers/io_struct.py +1 -1
  23. sglang/srt/managers/tokenizer_manager.py +14 -13
  24. sglang/srt/memory_pool.py +33 -6
  25. sglang/srt/model_config.py +6 -0
  26. sglang/srt/models/gemma2.py +436 -0
  27. sglang/srt/models/llama2.py +3 -3
  28. sglang/srt/models/llama_classification.py +10 -7
  29. sglang/srt/models/minicpm.py +373 -0
  30. sglang/srt/models/qwen2_moe.py +454 -0
  31. sglang/srt/openai_api_adapter.py +2 -2
  32. sglang/srt/openai_protocol.py +1 -1
  33. sglang/srt/server.py +18 -8
  34. sglang/srt/server_args.py +24 -20
  35. sglang/srt/utils.py +68 -35
  36. {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/METADATA +19 -13
  37. {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/RECORD +40 -36
  38. {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/WHEEL +1 -1
  39. {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/LICENSE +0 -0
  40. {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/model_runner.py

@@ -4,11 +4,9 @@ import importlib
  import importlib.resources
  import logging
  import pkgutil
- from dataclasses import dataclass
  from functools import lru_cache
- from typing import List, Optional, Type
+ from typing import Optional, Type

- import numpy as np
  import torch
  import torch.nn as nn
  from vllm.config import DeviceConfig, LoadConfig
@@ -17,7 +15,8 @@ from vllm.distributed import init_distributed_environment, initialize_model_para
  from vllm.model_executor.model_loader import get_model
  from vllm.model_executor.models import ModelRegistry

- from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
+ from sglang.global_config import global_config
+ from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata, global_server_args_dict
  from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
  from sglang.srt.server_args import ServerArgs
  from sglang.srt.utils import (
@@ -29,210 +28,6 @@ from sglang.srt.utils import (

  logger = logging.getLogger("srt.model_runner")

- # for server args in model endpoints
- global_server_args_dict = {}
-
-
- @dataclass
- class InputMetadata:
-     forward_mode: ForwardMode
-     batch_size: int
-     total_num_tokens: int
-     max_seq_len: int
-     req_pool_indices: torch.Tensor
-     start_loc: torch.Tensor
-     seq_lens: torch.Tensor
-     prefix_lens: torch.Tensor
-     positions: torch.Tensor
-     req_to_token_pool: ReqToTokenPool
-     token_to_kv_pool: TokenToKVPool
-
-     # for extend
-     extend_seq_lens: torch.Tensor = None
-     extend_start_loc: torch.Tensor = None
-     max_extend_len: int = 0
-
-     out_cache_loc: torch.Tensor = None
-     out_cache_cont_start: torch.Tensor = None
-     out_cache_cont_end: torch.Tensor = None
-
-     other_kv_index: torch.Tensor = None
-     return_logprob: bool = False
-     top_logprobs_nums: List[int] = None
-
-     # for flashinfer
-     qo_indptr: torch.Tensor = None
-     kv_indptr: torch.Tensor = None
-     kv_indices: torch.Tensor = None
-     kv_last_page_len: torch.Tensor = None
-     flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
-     flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
-     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
-
-     def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
-         if (
-             self.forward_mode == ForwardMode.PREFILL
-             or self.forward_mode == ForwardMode.EXTEND
-         ):
-             paged_kernel_lens = self.prefix_lens
-             self.no_prefix = torch.all(self.prefix_lens == 0)
-         else:
-             paged_kernel_lens = self.seq_lens
-
-         self.kv_indptr = torch.zeros(
-             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-         )
-         self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-         self.kv_last_page_len = torch.ones(
-             (self.batch_size,), dtype=torch.int32, device="cuda"
-         )
-         req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-         paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-         self.kv_indices = torch.cat(
-             [
-                 self.req_to_token_pool.req_to_token[
-                     req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-                 ]
-                 for i in range(self.batch_size)
-             ],
-             dim=0,
-         ).contiguous()
-
-         if (
-             self.forward_mode == ForwardMode.PREFILL
-             or self.forward_mode == ForwardMode.EXTEND
-         ):
-             # extend part
-             self.qo_indptr = torch.zeros(
-                 (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-             )
-             self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-
-             self.flashinfer_prefill_wrapper_ragged.end_forward()
-             self.flashinfer_prefill_wrapper_ragged.begin_forward(
-                 self.qo_indptr,
-                 self.qo_indptr.clone(),
-                 num_qo_heads,
-                 num_kv_heads,
-                 head_dim,
-             )
-
-             # cached part
-             self.flashinfer_prefill_wrapper_paged.end_forward()
-             self.flashinfer_prefill_wrapper_paged.begin_forward(
-                 self.qo_indptr,
-                 self.kv_indptr,
-                 self.kv_indices,
-                 self.kv_last_page_len,
-                 num_qo_heads,
-                 num_kv_heads,
-                 head_dim,
-                 1
-             )
-         else:
-             self.flashinfer_decode_wrapper.end_forward()
-             self.flashinfer_decode_wrapper.begin_forward(
-                 self.kv_indptr,
-                 self.kv_indices,
-                 self.kv_last_page_len,
-                 num_qo_heads,
-                 num_kv_heads,
-                 head_dim,
-                 1,
-                 pos_encoding_mode="NONE",
-                 data_type=self.token_to_kv_pool.kv_data[0].dtype
-             )
-
-     def init_extend_args(self):
-         self.extend_seq_lens = self.seq_lens - self.prefix_lens
-         self.extend_start_loc = torch.zeros_like(self.seq_lens)
-         self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-         self.max_extend_len = int(torch.max(self.extend_seq_lens))
-
-     @classmethod
-     def create(
-         cls,
-         model_runner,
-         tp_size,
-         forward_mode,
-         req_pool_indices,
-         seq_lens,
-         prefix_lens,
-         position_ids_offsets,
-         out_cache_loc,
-         out_cache_cont_start=None,
-         out_cache_cont_end=None,
-         top_logprobs_nums=None,
-         return_logprob=False,
-         flashinfer_prefill_wrapper_ragged=None,
-         flashinfer_prefill_wrapper_paged=None,
-         flashinfer_decode_wrapper=None,
-     ):
-         batch_size = len(req_pool_indices)
-         start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-         start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
-         total_num_tokens = int(torch.sum(seq_lens))
-         max_seq_len = int(torch.max(seq_lens))
-
-         if forward_mode == ForwardMode.DECODE:
-             positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
-             other_kv_index = model_runner.req_to_token_pool.req_to_token[
-                 req_pool_indices[0], seq_lens[0] - 1
-             ].item()
-         else:
-             seq_lens_cpu = seq_lens.cpu().numpy()
-             prefix_lens_cpu = prefix_lens.cpu().numpy()
-             position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
-             positions = torch.tensor(
-                 np.concatenate(
-                     [
-                         np.arange(
-                             prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
-                             seq_lens_cpu[i] + position_ids_offsets_cpu[i],
-                         )
-                         for i in range(batch_size)
-                     ],
-                     axis=0,
-                 ),
-                 device="cuda",
-             )
-             other_kv_index = None
-
-         ret = cls(
-             forward_mode=forward_mode,
-             batch_size=batch_size,
-             total_num_tokens=total_num_tokens,
-             max_seq_len=max_seq_len,
-             req_pool_indices=req_pool_indices,
-             start_loc=start_loc,
-             seq_lens=seq_lens,
-             prefix_lens=prefix_lens,
-             positions=positions,
-             req_to_token_pool=model_runner.req_to_token_pool,
-             token_to_kv_pool=model_runner.token_to_kv_pool,
-             out_cache_loc=out_cache_loc,
-             out_cache_cont_start=out_cache_cont_start,
-             out_cache_cont_end=out_cache_cont_end,
-             other_kv_index=other_kv_index,
-             return_logprob=return_logprob,
-             top_logprobs_nums=top_logprobs_nums,
-             flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
-             flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
-             flashinfer_decode_wrapper=flashinfer_decode_wrapper,
-         )
-
-         if forward_mode == ForwardMode.EXTEND:
-             ret.init_extend_args()
-
-         if not global_server_args_dict.get("disable_flashinfer", False):
-             ret.init_flashinfer_args(
-                 model_runner.model_config.num_attention_heads // tp_size,
-                 model_runner.model_config.get_num_kv_heads(tp_size),
-                 model_runner.model_config.head_dim
-             )
-
-         return ret
-

  class ModelRunner:
      def __init__(
@@ -245,6 +40,7 @@ class ModelRunner:
          nccl_port: int,
          server_args: ServerArgs,
      ):
+         # Parse args
          self.model_config = model_config
          self.mem_fraction_static = mem_fraction_static
          self.gpu_id = gpu_id
@@ -256,10 +52,12 @@
          monkey_patch_vllm_dummy_weight_loader()

          # Init torch distributed
-         logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
          torch.cuda.set_device(self.gpu_id)
          logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
-         monkey_patch_vllm_p2p_access_check(self.gpu_id)
+
+         if not server_args.enable_p2p_check:
+             monkey_patch_vllm_p2p_access_check(self.gpu_id)
+
          if server_args.nccl_init_addr:
              nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
          else:
@@ -269,7 +67,7 @@
              world_size=self.tp_size,
              rank=self.tp_rank,
              local_rank=self.gpu_id,
-             distributed_init_method=nccl_init_method
+             distributed_init_method=nccl_init_method,
          )
          initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
          total_gpu_memory = get_available_gpu_memory(
@@ -284,11 +82,8 @@
          )

          # Set some global args
-         global global_server_args_dict
-         global_server_args_dict = {
-             "disable_flashinfer": server_args.disable_flashinfer,
-             "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-         }
+         global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer
+         global_server_args_dict["attention_reduce_in_fp32"] = server_args.attention_reduce_in_fp32

          # Load the model and create memory pool
          self.load_model()
@@ -296,6 +91,9 @@
          self.init_cublas()
          self.init_flash_infer()

+         # Capture cuda graphs
+         self.init_cuda_graphs()
+
      def load_model(self):
          logger.info(
              f"[gpu_id={self.gpu_id}] Load weight begin. "
@@ -323,7 +121,7 @@
              device_config=device_config,
              load_config=load_config,
              lora_config=None,
-             vision_language_config=None,
+             multimodal_config=None,
              parallel_config=None,
              scheduler_config=None,
              cache_config=None,
@@ -341,7 +139,13 @@ class ModelRunner:
          )
          head_dim = self.model_config.head_dim
          head_num = self.model_config.get_num_kv_heads(self.tp_size)
-         cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * torch._utils._element_size(self.dtype)
+         cell_size = (
+             head_num
+             * head_dim
+             * self.model_config.num_hidden_layers
+             * 2
+             * torch._utils._element_size(self.dtype)
+         )
          rest_memory = available_gpu_memory - total_gpu_memory * (
              1 - self.mem_fraction_static
          )
@@ -382,64 +186,60 @@ class ModelRunner:
          return c

      def init_flash_infer(self):
-         if not global_server_args_dict.get("disable_flashinfer", False):
-             from flashinfer import (
-                 BatchPrefillWithRaggedKVCacheWrapper,
-                 BatchPrefillWithPagedKVCacheWrapper,
-                 BatchDecodeWithPagedKVCacheWrapper,
-             )
-             from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+         if self.server_args.disable_flashinfer:
+             self.flashinfer_prefill_wrapper_ragged = None
+             self.flashinfer_prefill_wrapper_paged = None
+             self.flashinfer_decode_wrapper = None
+             return

-             if not _grouped_size_compiled_for_decode_kernels(
-                 self.model_config.num_attention_heads // self.tp_size,
-                 self.model_config.get_num_kv_heads(self.tp_size)):
-                 use_tensor_cores = True
-             else:
-                 use_tensor_cores = False
+         from flashinfer import (
+             BatchDecodeWithPagedKVCacheWrapper,
+             BatchPrefillWithPagedKVCacheWrapper,
+             BatchPrefillWithRaggedKVCacheWrapper,
+         )
+         from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

-             workspace_buffers = torch.empty(
-                 3, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
-             )
-             self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-                 workspace_buffers[0], "NHD"
-             )
-             self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-                 workspace_buffers[1], "NHD"
-             )
-             self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                 workspace_buffers[2], "NHD", use_tensor_cores=use_tensor_cores
-             )
+         if not _grouped_size_compiled_for_decode_kernels(
+             self.model_config.num_attention_heads // self.tp_size,
+             self.model_config.get_num_kv_heads(self.tp_size),
+         ):
+             use_tensor_cores = True
          else:
-             self.flashinfer_prefill_wrapper_ragged = self.flashinfer_prefill_wrapper_paged = None
-             self.flashinfer_decode_wrapper = None
+             use_tensor_cores = False

-     @torch.inference_mode()
-     def forward_prefill(self, batch: Batch):
-         input_metadata = InputMetadata.create(
-             self,
-             forward_mode=ForwardMode.PREFILL,
-             tp_size=self.tp_size,
-             req_pool_indices=batch.req_pool_indices,
-             seq_lens=batch.seq_lens,
-             prefix_lens=batch.prefix_lens,
-             position_ids_offsets=batch.position_ids_offsets,
-             out_cache_loc=batch.out_cache_loc,
-             top_logprobs_nums=batch.top_logprobs_nums,
-             return_logprob=batch.return_logprob,
-             flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-             flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-             flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
+         self.flashinfer_workspace_buffers = torch.empty(
+             2, global_config.flashinfer_workspace_size, dtype=torch.uint8, device="cuda"
          )
-         return self.model.forward(
-             batch.input_ids, input_metadata.positions, input_metadata
+         self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+             self.flashinfer_workspace_buffers[0], "NHD"
+         )
+         self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+             self.flashinfer_workspace_buffers[1], "NHD"
          )
+         self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+             self.flashinfer_workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores
+         )
+
+     def init_cuda_graphs(self):
+         from sglang.srt.managers.controller.cuda_graph_runner import CudaGraphRunner
+
+         if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
+             self.cuda_graph_runner = None
+             return
+
+         logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
+         batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]
+         self.cuda_graph_runner = CudaGraphRunner(self, max_batch_size_to_capture=max(batch_size_list))
+         self.cuda_graph_runner.capture(batch_size_list)

      @torch.inference_mode()
-     def forward_extend(self, batch: Batch):
+     def forward_decode(self, batch: Batch):
+         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
+             return self.cuda_graph_runner.replay(batch)
+
          input_metadata = InputMetadata.create(
              self,
-             forward_mode=ForwardMode.EXTEND,
-             tp_size=self.tp_size,
+             forward_mode=ForwardMode.DECODE,
              req_pool_indices=batch.req_pool_indices,
              seq_lens=batch.seq_lens,
              prefix_lens=batch.prefix_lens,
@@ -447,32 +247,23 @@ class ModelRunner:
              out_cache_loc=batch.out_cache_loc,
              top_logprobs_nums=batch.top_logprobs_nums,
              return_logprob=batch.return_logprob,
-             flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-             flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-             flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
          )
          return self.model.forward(
              batch.input_ids, input_metadata.positions, input_metadata
          )

      @torch.inference_mode()
-     def forward_decode(self, batch: Batch):
+     def forward_extend(self, batch: Batch):
          input_metadata = InputMetadata.create(
              self,
-             forward_mode=ForwardMode.DECODE,
-             tp_size=self.tp_size,
+             forward_mode=ForwardMode.EXTEND,
              req_pool_indices=batch.req_pool_indices,
              seq_lens=batch.seq_lens,
              prefix_lens=batch.prefix_lens,
              position_ids_offsets=batch.position_ids_offsets,
              out_cache_loc=batch.out_cache_loc,
-             out_cache_cont_start=batch.out_cache_cont_start,
-             out_cache_cont_end=batch.out_cache_cont_end,
              top_logprobs_nums=batch.top_logprobs_nums,
              return_logprob=batch.return_logprob,
-             flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-             flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-             flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
          )
          return self.model.forward(
              batch.input_ids, input_metadata.positions, input_metadata
@@ -483,17 +274,13 @@ class ModelRunner:
          input_metadata = InputMetadata.create(
              self,
              forward_mode=ForwardMode.EXTEND,
-             tp_size=self.tp_size,
              req_pool_indices=batch.req_pool_indices,
              seq_lens=batch.seq_lens,
              prefix_lens=batch.prefix_lens,
              position_ids_offsets=batch.position_ids_offsets,
              out_cache_loc=batch.out_cache_loc,
-             top_logprobs_nums=batch.top_logprobs_nums,
              return_logprob=batch.return_logprob,
-             flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-             flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-             flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
+             top_logprobs_nums=batch.top_logprobs_nums,
          )
          return self.model.forward(
              batch.input_ids,
@@ -511,8 +298,6 @@ class ModelRunner:
              return self.forward_decode(batch)
          elif forward_mode == ForwardMode.EXTEND:
              return self.forward_extend(batch)
-         elif forward_mode == ForwardMode.PREFILL:
-             return self.forward_prefill(batch)
          else:
              raise ValueError(f"Invaid forward mode: {forward_mode}")

sglang/srt/managers/controller/tp_worker.py

@@ -34,11 +34,11 @@ from sglang.srt.managers.io_struct import (
  from sglang.srt.model_config import ModelConfig
  from sglang.srt.server_args import ModelPortArgs, ServerArgs
  from sglang.srt.utils import (
+     connect_rpyc_service,
      get_int_token_logit_bias,
      is_multimodal_model,
      set_random_seed,
      start_rpyc_service_process,
-     connect_rpyc_service,
      suppress_other_loggers,
  )
  from sglang.utils import get_exception_traceback
@@ -98,7 +98,7 @@ class ModelTpServer:
          )
          self.max_total_num_tokens = self.model_runner.max_total_num_tokens
          self.max_prefill_tokens = (
-             4096
+             8192
              if server_args.max_prefill_tokens is None
              else server_args.max_prefill_tokens
          )
@@ -314,11 +314,9 @@ class ModelTpServer:
          self.forward_queue.append(req)

      def get_new_fill_batch(self) -> Optional[Batch]:
-         if (
-             self.running_batch is not None
-             and len(self.running_batch.reqs) > self.max_running_requests
-         ):
-             return None
+         running_bs = len(self.running_batch.reqs) if self.running_batch is not None else 0
+         if running_bs >= self.max_running_requests:
+             return

          # Compute matched prefix length
          for req in self.forward_queue:
@@ -368,9 +366,11 @@ class ModelTpServer:
              if (
                  req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                  < available_size
-                 and (req.extend_input_len + new_batch_input_tokens
-                 <= self.max_prefill_tokens
-                 or len(can_run_list) == 0)
+                 and (
+                     req.extend_input_len + new_batch_input_tokens
+                     <= self.max_prefill_tokens
+                     or len(can_run_list) == 0
+                 )
              ):
                  delta = self.tree_cache.inc_lock_ref(req.last_node)
                  available_size += delta
@@ -392,6 +392,10 @@ class ModelTpServer:
                  new_batch_input_tokens += req.extend_input_len
              else:
                  break
+
+             if running_bs + len(can_run_list) >= self.max_running_requests:
+                 break
+
          if len(can_run_list) == 0:
              return None

@@ -452,7 +456,9 @@ class ModelTpServer:
                      next_token_ids,
                  ].tolist()
                  output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
-                 output.normalized_prompt_logprobs = output.normalized_prompt_logprobs.tolist()
+                 output.normalized_prompt_logprobs = (
+                     output.normalized_prompt_logprobs.tolist()
+                 )

              next_token_ids = next_token_ids.tolist()
          else:
@@ -582,7 +588,9 @@ class ModelTpServer:
              req.check_finished()

              if req.return_logprob:
-                 req.decode_token_logprobs.append((next_token_logprobs[i], next_token_id))
+                 req.decode_token_logprobs.append(
+                     (next_token_logprobs[i], next_token_id)
+                 )
                  if req.top_logprobs_num > 0:
                      req.decode_top_logprobs.append(output.decode_top_logprobs[i])

@@ -759,16 +767,27 @@ class ModelTpClient:
              with ThreadPoolExecutor(self.tp_size) as executor:
                  # Launch model processes
                  if server_args.nnodes == 1:
-                     self.procs = list(executor.map(
-                         lambda args: start_rpyc_service_process(*args),
-                         [(ModelTpService, p) for p in model_port_args.model_tp_ports],
-                     ))
+                     self.procs = list(
+                         executor.map(
+                             lambda args: start_rpyc_service_process(*args),
+                             [
+                                 (ModelTpService, p)
+                                 for p in model_port_args.model_tp_ports
+                             ],
+                         )
+                     )
                      addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
                  else:
-                     addrs = [(ip, port) for ip, port in zip(model_port_args.model_tp_ips, model_port_args.model_tp_ports)]
-
-                 self.model_services = list(executor.map(
-                     lambda args: connect_rpyc_service(*args), addrs))
+                     addrs = [
+                         (ip, port)
+                         for ip, port in zip(
+                             model_port_args.model_tp_ips, model_port_args.model_tp_ports
+                         )
+                     ]
+
+                 self.model_services = list(
+                     executor.map(lambda args: connect_rpyc_service(*args), addrs)
+                 )

                  # Init model
                  def init_model(i):
sglang/srt/managers/detokenizer_manager.py

@@ -11,7 +11,7 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
  from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
  from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
  from sglang.srt.server_args import PortArgs, ServerArgs
- from sglang.utils import get_exception_traceback, graceful_registry
+ from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry

  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -57,6 +57,8 @@ class DetokenizerManager:
              output_strs = []
              for i in range(len(recv_obj.rids)):
                  new_text = read_texts[i][len(surr_texts[i]) :]
+                 if recv_obj.finished_reason[i] is None:
+                     new_text = find_printable_text(new_text)
                  output_strs.append(recv_obj.decoded_texts[i] + new_text)

                  if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
@@ -67,7 +69,7 @@ class DetokenizerManager:
              self.send_to_tokenizer.send_pyobj(
                  BatchStrOut(
                      rids=recv_obj.rids,
-                     output_str=output_strs,
+                     output_strs=output_strs,
                      meta_info=recv_obj.meta_info,
                      finished_reason=recv_obj.finished_reason,
                  )
sglang/srt/managers/io_struct.py

@@ -122,7 +122,7 @@ class BatchTokenIDOut:
  @dataclass
  class BatchStrOut:
      rids: List[str]
-     output_str: List[str]
+     output_strs: List[str]
      meta_info: List[Dict]
      finished_reason: List[BaseFinishReason]

sglang/srt/managers/tokenizer_manager.py

@@ -316,7 +316,7 @@ class TokenizerManager:

                  recv_obj.meta_info[i]["id"] = rid
                  out_dict = {
-                     "text": recv_obj.output_str[i],
+                     "text": recv_obj.output_strs[i],
                      "meta_info": recv_obj.meta_info[i],
                  }
                  state.out_list.append(out_dict)
@@ -333,17 +333,18 @@ class TokenizerManager:
              ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens(
                  ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
              )
-         if top_logprobs_num > 0:
-             ret["meta_info"][
-                 "prefill_top_logprobs"
-             ] = self.detokenize_top_logprobs_tokens(
-                 ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
-             )
-             ret["meta_info"][
-                 "decode_top_logprobs"
-             ] = self.detokenize_top_logprobs_tokens(
-                 ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
-             )
+
+             if top_logprobs_num > 0:
+                 ret["meta_info"][
+                     "prefill_top_logprobs"
+                 ] = self.detokenize_top_logprobs_tokens(
+                     ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
+                 )
+                 ret["meta_info"][
+                     "decode_top_logprobs"
+                 ] = self.detokenize_top_logprobs_tokens(
+                     ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+                 )
          return ret

      def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
@@ -383,7 +384,7 @@ def get_pixel_values(
      try:
          processor = processor or global_processor
          image, image_size = load_image(image_data)
-         if image_size != None:
+         if image_size is not None:
              image_hash = hash(image_data)
              pixel_values = processor.image_processor(image)["pixel_values"]
              for _ in range(len(pixel_values)):