sglang 0.3.4.post1__py3-none-any.whl → 0.3.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. sglang/srt/configs/model_config.py +25 -2
  2. sglang/srt/constrained/fsm_cache.py +10 -3
  3. sglang/srt/hf_transformers_utils.py +14 -0
  4. sglang/srt/layers/attention/flashinfer_backend.py +5 -5
  5. sglang/srt/layers/logits_processor.py +5 -5
  6. sglang/srt/layers/rotary_embedding.py +15 -48
  7. sglang/srt/layers/sampler.py +51 -39
  8. sglang/srt/managers/data_parallel_controller.py +1 -1
  9. sglang/srt/managers/detokenizer_manager.py +4 -0
  10. sglang/srt/managers/io_struct.py +10 -0
  11. sglang/srt/managers/schedule_batch.py +13 -3
  12. sglang/srt/managers/scheduler.py +8 -2
  13. sglang/srt/managers/tokenizer_manager.py +14 -0
  14. sglang/srt/managers/tp_worker_overlap_thread.py +58 -21
  15. sglang/srt/mem_cache/memory_pool.py +10 -3
  16. sglang/srt/model_executor/cuda_graph_runner.py +29 -21
  17. sglang/srt/model_executor/forward_batch_info.py +6 -9
  18. sglang/srt/model_executor/model_runner.py +2 -2
  19. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
  20. sglang/srt/sampling/sampling_params.py +5 -7
  21. sglang/srt/server.py +12 -0
  22. sglang/test/run_eval.py +2 -0
  23. sglang/test/srt/sampling/penaltylib/utils.py +1 -0
  24. sglang/test/test_utils.py +100 -3
  25. sglang/version.py +1 -1
  26. {sglang-0.3.4.post1.dist-info → sglang-0.3.4.post2.dist-info}/METADATA +13 -14
  27. {sglang-0.3.4.post1.dist-info → sglang-0.3.4.post2.dist-info}/RECORD +30 -30
  28. {sglang-0.3.4.post1.dist-info → sglang-0.3.4.post2.dist-info}/LICENSE +0 -0
  29. {sglang-0.3.4.post1.dist-info → sglang-0.3.4.post2.dist-info}/WHEEL +0 -0
  30. {sglang-0.3.4.post1.dist-info → sglang-0.3.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/configs/model_config.py
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """

+ import logging
+ import os
  from enum import IntEnum, auto
  from typing import Optional

@@ -20,6 +22,8 @@ from transformers import PretrainedConfig

  from sglang.srt.hf_transformers_utils import get_config, get_context_length

+ logger = logging.getLogger(__name__)
+

  class AttentionArch(IntEnum):
  MLA = auto()
@@ -46,10 +50,29 @@ class ModelConfig:
  model_override_args=model_override_args,
  )
  self.hf_text_config = get_hf_text_config(self.hf_config)
+ derived_context_len = get_context_length(self.hf_text_config)
+ allow_long_context = os.environ.get(
+ "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", None
+ )
+
  if context_length is not None:
- self.context_len = context_length
+ if context_length > derived_context_len:
+ if allow_long_context:
+ logger.warning(
+ f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
+ f"This may lead to incorrect model outputs or CUDA errors."
+ )
+ self.context_len = context_length
+ else:
+ raise ValueError(
+ f"User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
+ f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config. "
+ f"To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
+ )
+ else:
+ self.context_len = context_length
  else:
- self.context_len = get_context_length(self.hf_text_config)
+ self.context_len = derived_context_len

  # Unify the config keys for hf_text_config
  self.head_dim = getattr(
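
For orientation, a minimal sketch of the context-length gating added above (the helper name `pick_context_len` is hypothetical; in the package the logic lives inline in `ModelConfig.__init__`):

```python
import logging
import os

logger = logging.getLogger(__name__)


def pick_context_len(user_context_length, derived_context_len):
    """Illustrative stand-in for the gating logic in ModelConfig.__init__."""
    allow_long_context = os.environ.get(
        "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", None
    )
    if user_context_length is None:
        return derived_context_len
    if user_context_length <= derived_context_len:
        return user_context_length
    if allow_long_context:
        # Overriding is allowed but risky: positions beyond the derived limit
        # may produce incorrect outputs or CUDA errors.
        logger.warning(
            "context_length %d exceeds derived context_length %d",
            user_context_length,
            derived_context_len,
        )
        return user_context_length
    raise ValueError(
        "User-specified context_length exceeds the derived context_length; "
        "set SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1 to override."
    )
```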

sglang/srt/constrained/fsm_cache.py
@@ -73,9 +73,16 @@ class FSMCache(BaseToolCache):
  def init_value(self, key):
  key_type, key_string = key
  if key_type == "json":
- regex = build_regex_from_schema(
- key_string, whitespace_pattern=self.constrained_json_whitespace_pattern
- )
+ try:
+ regex = build_regex_from_schema(
+ key_string,
+ whitespace_pattern=self.constrained_json_whitespace_pattern,
+ )
+ except NotImplementedError as e:
+ logger.warning(
+ f"skip invalid json schema: json_schema={key_string}, {e=}"
+ )
+ return None, key_string
  elif key_type == "regex":
  regex = key_string
  else:

sglang/srt/hf_transformers_utils.py
@@ -163,6 +163,8 @@ def get_tokenizer(
  "Using a slow tokenizer. This might cause a significant "
  "slowdown. Consider using a fast tokenizer instead."
  )
+
+ attach_additional_stop_token_ids(tokenizer)
  return tokenizer


@@ -181,4 +183,16 @@ def get_processor(
  tokenizer_revision=tokenizer_revision,
  **kwargs,
  )
+
+ attach_additional_stop_token_ids(processor.tokenizer)
  return processor
+
+
+ def attach_additional_stop_token_ids(tokenizer):
+ # Special handling for stop token <|eom_id|> generated by llama 3 tool use.
+ if "<|eom_id|>" in tokenizer.get_added_vocab():
+ tokenizer.additional_stop_token_ids = set(
+ [tokenizer.get_added_vocab()["<|eom_id|>"]]
+ )
+ else:
+ tokenizer.additional_stop_token_ids = None
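
A quick check of the new helper with a stand-in tokenizer object (`FakeTokenizer` and the token id below are illustrative; real callers pass a Hugging Face tokenizer, and `get_tokenizer`/`get_processor` now attach the attribute automatically):

```python
class FakeTokenizer:
    """Minimal stand-in exposing the only method the helper touches."""

    def get_added_vocab(self):
        return {"<|eom_id|>": 128008}  # example id for the llama 3 tool-use token


def attach_additional_stop_token_ids(tokenizer):
    # Same body as the function added in the diff above.
    if "<|eom_id|>" in tokenizer.get_added_vocab():
        tokenizer.additional_stop_token_ids = set(
            [tokenizer.get_added_vocab()["<|eom_id|>"]]
        )
    else:
        tokenizer.additional_stop_token_ids = None


tok = FakeTokenizer()
attach_additional_stop_token_ids(tok)
assert tok.additional_stop_token_ids == {128008}
```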

sglang/srt/layers/attention/flashinfer_backend.py
@@ -337,7 +337,7 @@ class FlashInferIndicesUpdaterDecode:
  def update(
  self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens
  ):
- # Keep the signature for type checking, will be initialized during runtime
+ # Keep the signature for type checking. It will be assigned during runtime.
  raise NotImplementedError()

  def update_single_wrapper(
@@ -432,8 +432,8 @@ class FlashInferIndicesUpdaterDecode:
  kv_start_idx,
  ):
  bs = len(req_pool_indices)
+ kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
  kv_indptr = kv_indptr[: bs + 1]
- kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
  kv_indices = torch.empty(
  paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
  )
@@ -497,7 +497,7 @@ class FlashInferIndicesUpdaterPrefill:
  self.update = self.update_single_wrapper

  def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens):
- # Keep the signature for type checking, will be initialized during runtime
+ # Keep the signature for type checking. It will be assigned during runtime.
  raise NotImplementedError()

  def update_single_wrapper(
@@ -589,8 +589,8 @@ class FlashInferIndicesUpdaterPrefill:
  use_ragged,
  ):
  bs = len(req_pool_indices)
+ kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
  kv_indptr = kv_indptr[: bs + 1]
- kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
  kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
  create_flashinfer_kv_indices_triton[(bs,)](
  self.req_to_token,
@@ -602,8 +602,8 @@ class FlashInferIndicesUpdaterPrefill:
  self.max_context_len,
  )

+ qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
  qo_indptr = qo_indptr[: bs + 1]
- qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

  # extend part
  if use_ragged:
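
The `kv_indptr`/`qo_indptr` changes write the cumulative sums into the preallocated buffer before narrowing the view. A tiny torch-only illustration of that construction order, with made-up lengths:

```python
import torch

paged_kernel_lens = torch.tensor([3, 5, 2], dtype=torch.int32)
bs = paged_kernel_lens.numel()

kv_indptr_buf = torch.zeros(16, dtype=torch.int32)  # preallocated buffer (> bs + 1)
kv_indptr_buf[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr_buf[: bs + 1]

assert kv_indptr.tolist() == [0, 3, 8, 10]
```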

sglang/srt/layers/logits_processor.py
@@ -33,17 +33,17 @@ class LogitsProcessorOutput:
  # The logits of the next tokens. shape: [#seq, vocab_size]
  next_token_logits: torch.Tensor
  # The logprobs of the next tokens. shape: [#seq, vocab_size]
- next_token_logprobs: torch.Tensor
+ next_token_logprobs: torch.Tensor = None

  # The normlaized logprobs of prompts. shape: [#seq]
- normalized_prompt_logprobs: torch.Tensor
+ normalized_prompt_logprobs: torch.Tensor = None
  # The logprobs of input tokens. shape: [#token, vocab_size]
- input_token_logprobs: torch.Tensor
+ input_token_logprobs: torch.Tensor = None

  # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
- input_top_logprobs: List
+ input_top_logprobs: List = None
  # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
- output_top_logprobs: List
+ output_top_logprobs: List = None


  @dataclasses.dataclass
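
With the new `None` defaults, decode paths that skip logprob computation can build the output from the next-token logits alone. A sketch using a local copy of the dataclass (shapes are arbitrary):

```python
import dataclasses
from typing import List

import torch


@dataclasses.dataclass
class LogitsProcessorOutput:  # local copy mirroring the fields in the diff
    next_token_logits: torch.Tensor
    next_token_logprobs: torch.Tensor = None
    normalized_prompt_logprobs: torch.Tensor = None
    input_token_logprobs: torch.Tensor = None
    input_top_logprobs: List = None
    output_top_logprobs: List = None


out = LogitsProcessorOutput(next_token_logits=torch.randn(2, 32000))
assert out.next_token_logprobs is None  # filled later only if logprobs are requested
```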

sglang/srt/layers/rotary_embedding.py
@@ -22,64 +22,33 @@ class MRotaryEmbedding:

  @staticmethod
  def get_input_positions(
- input_tokens: List[int],
+ input_tokens: torch.Tensor,
  image_grid_thw: Union[List[List[int]], torch.Tensor],
- video_grid_thw: Union[List[List[int]], torch.Tensor],
- image_token_id: int,
- video_token_id: int,
  vision_start_token_id: int,
- vision_end_token_id: int,
  spatial_merge_size: int,
  context_len: int = 0,
- extend_prefix_len: int = 0,
  ) -> Tuple[List[List[int]], int]:
  """Get mrope input positions and delta value."""

  if isinstance(image_grid_thw, torch.Tensor):
  image_grid_thw = image_grid_thw.tolist()
- if isinstance(video_grid_thw, torch.Tensor):
- video_grid_thw = video_grid_thw.tolist()

- input_tokens_tensor = torch.tensor(input_tokens)
  vision_start_indices = torch.argwhere(
- input_tokens_tensor == vision_start_token_id
+ input_tokens == vision_start_token_id
  ).squeeze(1)
- vision_tokens = input_tokens_tensor[vision_start_indices + 1]
- image_nums = (vision_tokens == image_token_id).sum()
- video_nums = (vision_tokens == video_token_id).sum()
+ image_indices = vision_start_indices + 1
+ image_nums = image_indices.shape[0]
  llm_pos_ids_list: list = []

  st = 0
- remain_images, remain_videos = image_nums, video_nums
-
- image_index, video_index = 0, 0
- for _ in range(image_nums + video_nums):
- if image_token_id in input_tokens and remain_images > 0:
- ed_image = input_tokens.index(image_token_id, st)
- else:
- ed_image = len(input_tokens) + 1
- if video_token_id in input_tokens and remain_videos > 0:
- ed_video = input_tokens.index(video_token_id, st)
- else:
- ed_video = len(input_tokens) + 1
- if ed_image < ed_video:
- t, h, w = (
- image_grid_thw[image_index][0],
- image_grid_thw[image_index][1],
- image_grid_thw[image_index][2],
- )
- image_index += 1
- remain_images -= 1
- ed = ed_image
- else:
- t, h, w = (
- video_grid_thw[video_index][0],
- video_grid_thw[video_index][1],
- video_grid_thw[video_index][2],
- )
- video_index += 1
- remain_videos -= 1
- ed = ed_video
+ input_tokens_len = input_tokens.shape[0]
+ for image_index in range(image_nums):
+ ed = image_indices[image_index].item()
+ t, h, w = (
+ image_grid_thw[image_index][0],
+ image_grid_thw[image_index][1],
+ image_grid_thw[image_index][2],
+ )
  llm_grid_t, llm_grid_h, llm_grid_w = (
  t,
  h // spatial_merge_size,
@@ -115,18 +84,16 @@ class MRotaryEmbedding:
  )
  st = ed + llm_grid_t * llm_grid_h * llm_grid_w

- if st < len(input_tokens):
+ if st < input_tokens_len:
  st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
- text_len = len(input_tokens) - st
+ text_len = input_tokens_len - st
  llm_pos_ids_list.append(
  torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
  )

  llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
  llm_positions = llm_positions[:, context_len:]
- mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
- llm_positions += extend_prefix_len
-
+ mrope_position_delta = (llm_positions.max() + 1 - input_tokens_len).item()
  return llm_positions.tolist(), mrope_position_delta

  @staticmethod

sglang/srt/layers/sampler.py
@@ -1,4 +1,5 @@
  import logging
+ import os
  from typing import Union

  import torch
@@ -17,6 +18,11 @@ if is_flashinfer_available():
  top_p_renorm_prob,
  )

+
+ # Crash on warning if we are running CI tests
+ crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
+
+
  logger = logging.getLogger(__name__)


@@ -33,56 +39,62 @@ class Sampler(nn.Module):
  if isinstance(logits, LogitsProcessorOutput):
  logits = logits.next_token_logits

- # Post process logits
  logits = logits.contiguous()
- logits.div_(sampling_info.temperatures)
- probs = torch.softmax(logits, dim=-1)
- logits = None
- del logits
-
- if self.use_nan_detectioin and torch.any(torch.isnan(probs)):
- logger.warning("Detected errors during sampling! NaN in the probability.")
- probs = torch.where(
- torch.isnan(probs), torch.full_like(probs, 1e-10), probs
+
+ if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
+ logger.warning("Detected errors during sampling! NaN in the logits.")
+ logits = torch.where(
+ torch.isnan(logits), torch.full_like(logits, -1e5), logits
  )
+ exit(1) if crash_on_warning else None

  if sampling_info.is_all_greedy:
  # Use torch.argmax if all requests use greedy sampling
- batch_next_token_ids = torch.argmax(probs, -1)
- elif global_server_args_dict["sampling_backend"] == "flashinfer":
- max_top_k_round, batch_size = 32, probs.shape[0]
- uniform_samples = torch.rand(
- (max_top_k_round, batch_size), device=probs.device
- )
- if sampling_info.need_min_p_sampling:
- probs = top_k_renorm_prob(probs, sampling_info.top_ks)
- probs = top_p_renorm_prob(probs, sampling_info.top_ps)
- batch_next_token_ids, success = min_p_sampling_from_probs(
- probs, uniform_samples, sampling_info.min_ps
+ batch_next_token_ids = torch.argmax(logits, -1)
+ else:
+ # Post process logits
+ logits.div_(sampling_info.temperatures)
+ probs = torch.softmax(logits, dim=-1)
+ logits = None
+ del logits
+
+ if global_server_args_dict["sampling_backend"] == "flashinfer":
+ max_top_k_round, batch_size = 32, probs.shape[0]
+ uniform_samples = torch.rand(
+ (max_top_k_round, batch_size), device=probs.device
  )
- else:
- batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+ if sampling_info.need_min_p_sampling:
+ probs = top_k_renorm_prob(probs, sampling_info.top_ks)
+ probs = top_p_renorm_prob(probs, sampling_info.top_ps)
+ batch_next_token_ids, success = min_p_sampling_from_probs(
+ probs, uniform_samples, sampling_info.min_ps
+ )
+ else:
+ batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+ probs,
+ uniform_samples,
+ sampling_info.top_ks,
+ sampling_info.top_ps,
+ filter_apply_order="joint",
+ )
+
+ if not torch.all(success):
+ logger.warning("Detected errors during sampling!")
+ batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
+ elif global_server_args_dict["sampling_backend"] == "pytorch":
+ # A slower fallback implementation with torch native operations.
+ batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
  probs,
- uniform_samples,
  sampling_info.top_ks,
  sampling_info.top_ps,
- filter_apply_order="joint",
+ sampling_info.min_ps,
+ )
+ else:
+ raise ValueError(
+ f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
  )

- if not torch.all(success):
- logger.warning("Detected errors during sampling!")
- batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
- elif global_server_args_dict["sampling_backend"] == "pytorch":
- # Here we provide a slower fallback implementation.
- batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
- probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
- )
- else:
- raise ValueError(
- f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
- )
-
- return batch_next_token_ids
+ return batch_next_token_ids.to(torch.int32)


  def top_k_top_p_min_p_sampling_from_probs_torch(
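
A hedged sketch of the reordered control flow in the sampler: the NaN guard now runs on the raw logits, greedy batches take `argmax` before any softmax, and only the non-greedy branch builds probabilities. `torch.multinomial` stands in here for the flashinfer / top-k / top-p / min-p kernels used in the package:

```python
import torch


def sample_next_tokens(
    logits: torch.Tensor, temperatures: torch.Tensor, is_all_greedy: bool
) -> torch.Tensor:
    logits = logits.contiguous()
    # Replace NaNs in the logits (the diff uses -1e5 as the fill value).
    if torch.any(torch.isnan(logits)):
        logits = torch.where(torch.isnan(logits), torch.full_like(logits, -1e5), logits)
    if is_all_greedy:
        # Greedy requests never need the softmax.
        return torch.argmax(logits, -1).to(torch.int32)
    probs = torch.softmax(logits / temperatures, dim=-1)
    # Simplified stand-in for the backend-specific sampling kernels.
    return torch.multinomial(probs, num_samples=1).squeeze(-1).to(torch.int32)
```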

sglang/srt/managers/data_parallel_controller.py
@@ -156,7 +156,7 @@ class DataParallelController:
  else:
  # Send other control messages to all workers
  for worker in self.workers:
- worker.queue.put(recv_req)
+ worker.send_pyobj(recv_req)


  def run_data_parallel_controller_process(

sglang/srt/managers/detokenizer_manager.py
@@ -27,6 +27,7 @@ from sglang.srt.managers.io_struct import (
  BatchEmbeddingOut,
  BatchStrOut,
  BatchTokenIDOut,
+ GetMemPoolSizeReqOutput,
  UpdateWeightReqOutput,
  )
  from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
@@ -111,6 +112,9 @@ class DetokenizerManager:
  # If it is a weight update request, no detokenization is needed.
  self.send_to_tokenizer.send_pyobj(recv_obj)
  continue
+ elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
+ self.send_to_tokenizer.send_pyobj(recv_obj)
+ continue
  elif self.tokenizer is None:
  # If the tokenizer is skipped, no detokenization is needed
  self.send_to_tokenizer.send_pyobj(recv_obj)

sglang/srt/managers/io_struct.py
@@ -353,3 +353,13 @@ class AbortReq:
  class ProfileReq(Enum):
  START_PROFILE = 1
  STOP_PROFILE = 2
+
+
+ @dataclass
+ class GetMemPoolSizeReq:
+ pass
+
+
+ @dataclass
+ class GetMemPoolSizeReqOutput:
+ size: int

sglang/srt/managers/schedule_batch.py
@@ -334,15 +334,20 @@ class Req:

  last_token_id = self.output_ids[-1]

- matched_eos = last_token_id in self.sampling_params.stop_token_ids
+ matched_eos = False

+ # Check stop token ids
+ if self.sampling_params.stop_token_ids:
+ matched_eos = last_token_id in self.sampling_params.stop_token_ids
  if self.tokenizer is not None:
  matched_eos |= last_token_id == self.tokenizer.eos_token_id
-
+ if self.tokenizer.additional_stop_token_ids:
+ matched_eos |= last_token_id in self.tokenizer.additional_stop_token_ids
  if matched_eos and not self.sampling_params.ignore_eos:
  self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
  return

+ # Check stop strings
  if len(self.sampling_params.stop_strs) > 0:
  tail_str = self.tokenizer.decode(
  self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
@@ -514,7 +519,12 @@ class ScheduleBatch:
  out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

  if out_cache_loc is None:
- logger.error("Prefill out of memory. Try to lower your batch size.")
+ phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
+ logger.error(
+ f"{phase_str} out of memory. Try to lower your batch size.\n"
+ f"Try to allocate {num_tokens} tokens.\n"
+ f"Avaliable tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
+ )
  if self.tree_cache is not None:
  self.tree_cache.pretty_print()
  exit(1)

sglang/srt/managers/scheduler.py
@@ -38,6 +38,8 @@ from sglang.srt.managers.io_struct import (
  BatchEmbeddingOut,
  BatchTokenIDOut,
  FlushCacheReq,
+ GetMemPoolSizeReq,
+ GetMemPoolSizeReqOutput,
  ProfileReq,
  TokenizedEmbeddingReqInput,
  TokenizedGenerateReqInput,
@@ -69,7 +71,6 @@ from sglang.srt.utils import (
  is_generation_model,
  is_multimodal_model,
  kill_parent_process,
- pytorch_profile,
  set_random_seed,
  suppress_other_loggers,
  )
@@ -363,6 +364,10 @@ class Scheduler:
  self.start_profile()
  else:
  self.stop_profile()
+ elif isinstance(recv_req, GetMemPoolSizeReq):
+ self.send_to_detokenizer.send_pyobj(
+ GetMemPoolSizeReqOutput(self.max_total_num_tokens)
+ )
  else:
  raise ValueError(f"Invalid request: {recv_req}")

@@ -416,7 +421,7 @@ class Scheduler:
  )

  # Truncate prompts that are too long
- if len(req.origin_input_ids) >= self.max_req_input_len:
+ if len(req.origin_input_ids) > self.max_req_input_len:
  logger.warning(
  "Request length is longer than the KV cache pool size or "
  "the max context length. Truncated!!!"
@@ -828,6 +833,7 @@ class Scheduler:

  if self.enable_overlap:
  logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
+ next_token_logprobs = logits_output.next_token_logprobs
  else:
  # Move next_token_ids and logprobs to cpu
  if batch.return_logprob:

sglang/srt/managers/tokenizer_manager.py
@@ -46,6 +46,8 @@ from sglang.srt.managers.io_struct import (
  EmbeddingReqInput,
  FlushCacheReq,
  GenerateReqInput,
+ GetMemPoolSizeReq,
+ GetMemPoolSizeReqOutput,
  ProfileReq,
  RewardReqInput,
  TokenizedEmbeddingReqInput,
@@ -531,6 +533,15 @@ class TokenizerManager:
  req = ProfileReq.STOP_PROFILE
  self.send_to_scheduler.send_pyobj(req)

+ async def get_memory_pool_size(self):
+ if self.to_create_loop:
+ self.create_handle_loop()
+
+ req = GetMemPoolSizeReq()
+ self.send_to_scheduler.send_pyobj(req)
+ self.mem_pool_size = asyncio.Future()
+ return await self.mem_pool_size
+
  async def update_weights(
  self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
  ):
@@ -590,6 +601,9 @@ class TokenizerManager:
  if isinstance(recv_obj, UpdateWeightReqOutput):
  self.model_update_result.set_result(recv_obj)
  continue
+ elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
+ self.mem_pool_size.set_result(recv_obj)
+ continue

  assert isinstance(
  recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
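
Sketch of how a caller awaits the new request/response pair, assuming an already-initialized `TokenizerManager` (the reply is a `GetMemPoolSizeReqOutput`, whose `size` the scheduler fills with `max_total_num_tokens`):

```python
async def report_mem_pool_size(tokenizer_manager):
    # Sends GetMemPoolSizeReq to the scheduler and awaits the reply that the
    # detokenizer forwards back to the tokenizer manager.
    out = await tokenizer_manager.get_memory_pool_size()
    print(f"KV cache pool size: {out.size} tokens")
```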

sglang/srt/managers/tp_worker_overlap_thread.py
@@ -32,6 +32,15 @@ from sglang.srt.server_args import ServerArgs
  logger = logging.getLogger(__name__)


+ @torch.compile(dynamic=True)
+ def resolve_future_token_ids(input_ids, future_token_ids_map):
+ input_ids[:] = torch.where(
+ input_ids < 0,
+ future_token_ids_map[torch.clamp(-input_ids, min=0)],
+ input_ids,
+ )
+
+
  class TpModelWorkerClient:
  """A tensor parallel model worker."""

@@ -94,46 +103,69 @@ class TpModelWorkerClient:
  while True:
  self.has_inflight_batch = False
  model_worker_batch, future_token_ids_ct = self.input_queue.get()
+ if not model_worker_batch:
+ break
  self.has_inflight_batch = True
  self.launch_event = threading.Event()

  # Resolve future tokens in the input
  input_ids = model_worker_batch.input_ids
- input_ids[:] = torch.where(
- input_ids < 0,
- self.future_token_ids_map[torch.clamp(-input_ids, min=0)],
- input_ids,
- )
+ resolve_future_token_ids(input_ids, self.future_token_ids_map)

  # Run forward
  logits_output, next_token_ids = self.worker.forward_batch_generation(
  model_worker_batch
  )
- self.launch_event.set()

  # Update the future token ids map
  bs = len(model_worker_batch.seq_lens)
- future_next_token_ids = torch.arange(
- -(future_token_ids_ct + bs),
- -(future_token_ids_ct),
- dtype=torch.int32,
- device=self.device,
- )
- self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
- torch.int32
- )
-
+ self.future_token_ids_map[
+ future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
+ ] = next_token_ids
+
+ # Copy results to the CPU
+ if model_worker_batch.return_logprob:
+ logits_output.next_token_logprobs = logits_output.next_token_logprobs[
+ torch.arange(len(next_token_ids), device=self.device),
+ next_token_ids,
+ ].to("cpu", non_blocking=True)
+ if logits_output.input_token_logprobs is not None:
+ logits_output.input_token_logprobs = (
+ logits_output.input_token_logprobs.to("cpu", non_blocking=True)
+ )
+ logits_output.normalized_prompt_logprobs = (
+ logits_output.normalized_prompt_logprobs.to(
+ "cpu", non_blocking=True
+ )
+ )
  next_token_ids = next_token_ids.to("cpu", non_blocking=True)
  copy_event = torch.cuda.Event(blocking=True)
  copy_event.record()
- self.copy_queue.put((copy_event, next_token_ids))
+
+ self.launch_event.set()
+ self.copy_queue.put((copy_event, logits_output, next_token_ids))

  def copy_thread_func(self):
  while True:
- copy_event, next_token_ids = self.copy_queue.get()
+ copy_event, logits_output, next_token_ids = self.copy_queue.get()
+ if not copy_event:
+ break
  while not copy_event.query():
  time.sleep(1e-5)
- self.output_queue.put((None, next_token_ids.tolist()))
+
+ if logits_output.next_token_logprobs is not None:
+ logits_output.next_token_logprobs = (
+ logits_output.next_token_logprobs.tolist()
+ )
+ if logits_output.input_token_logprobs is not None:
+ logits_output.input_token_logprobs = (
+ logits_output.input_token_logprobs.tolist()
+ )
+ logits_output.normalized_prompt_logprobs = (
+ logits_output.normalized_prompt_logprobs.tolist()
+ )
+
+ self.output_queue.put((logits_output, next_token_ids.tolist()))

  def resulve_batch_result(self, bid: int):
  logits_output, next_token_ids = self.output_queue.get()
@@ -149,8 +181,9 @@ class TpModelWorkerClient:
  # Allocate output future objects
  bs = len(model_worker_batch.seq_lens)
  future_next_token_ids = torch.arange(
- -(self.future_token_ids_ct + bs),
- -(self.future_token_ids_ct),
+ -(self.future_token_ids_ct + 1),
+ -(self.future_token_ids_ct + 1 + bs),
+ -1,
  dtype=torch.int32,
  device=self.device,
  )
@@ -170,3 +203,7 @@ class TpModelWorkerClient:
  recv_req.model_path, recv_req.load_format
  )
  return success, message
+
+ def __delete__(self):
+ self.input_queue.put((None, None))
+ self.copy_queue.put((None, None, None))
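
A standalone illustration of the placeholder scheme used by `resolve_future_token_ids`: not-yet-computed tokens are encoded as negative indices into `future_token_ids_map` (slot 0 is unused), while real token ids pass through unchanged. The tensors below are made up:

```python
import torch


def resolve_future_token_ids(input_ids, future_token_ids_map):
    # Same body as in the diff above, without the @torch.compile decorator.
    input_ids[:] = torch.where(
        input_ids < 0,
        future_token_ids_map[torch.clamp(-input_ids, min=0)],
        input_ids,
    )


future_map = torch.tensor([0, 111, 222, 333], dtype=torch.int32)
batch = torch.tensor([-1, 42, -3], dtype=torch.int32)  # -k resolves to future_map[k]
resolve_future_token_ids(batch, future_map)
assert batch.tolist() == [111, 42, 333]
```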