sglang 0.3.4.post1__py3-none-any.whl → 0.3.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/configs/model_config.py +25 -2
- sglang/srt/constrained/fsm_cache.py +10 -3
- sglang/srt/hf_transformers_utils.py +14 -0
- sglang/srt/layers/attention/flashinfer_backend.py +5 -5
- sglang/srt/layers/logits_processor.py +5 -5
- sglang/srt/layers/rotary_embedding.py +15 -48
- sglang/srt/layers/sampler.py +51 -39
- sglang/srt/managers/data_parallel_controller.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +4 -0
- sglang/srt/managers/io_struct.py +10 -0
- sglang/srt/managers/schedule_batch.py +13 -3
- sglang/srt/managers/scheduler.py +8 -2
- sglang/srt/managers/tokenizer_manager.py +14 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +58 -21
- sglang/srt/mem_cache/memory_pool.py +10 -3
- sglang/srt/model_executor/cuda_graph_runner.py +29 -21
- sglang/srt/model_executor/forward_batch_info.py +6 -9
- sglang/srt/model_executor/model_runner.py +2 -2
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
- sglang/srt/sampling/sampling_params.py +5 -7
- sglang/srt/server.py +12 -0
- sglang/test/run_eval.py +2 -0
- sglang/test/srt/sampling/penaltylib/utils.py +1 -0
- sglang/test/test_utils.py +100 -3
- sglang/version.py +1 -1
- {sglang-0.3.4.post1.dist-info → sglang-0.3.4.post2.dist-info}/METADATA +13 -14
- {sglang-0.3.4.post1.dist-info → sglang-0.3.4.post2.dist-info}/RECORD +30 -30
- {sglang-0.3.4.post1.dist-info → sglang-0.3.4.post2.dist-info}/LICENSE +0 -0
- {sglang-0.3.4.post1.dist-info → sglang-0.3.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.3.4.post1.dist-info → sglang-0.3.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/configs/model_config.py CHANGED
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+import logging
+import os
 from enum import IntEnum, auto
 from typing import Optional
 
@@ -20,6 +22,8 @@ from transformers import PretrainedConfig
 
 from sglang.srt.hf_transformers_utils import get_config, get_context_length
 
+logger = logging.getLogger(__name__)
+
 
 class AttentionArch(IntEnum):
     MLA = auto()
@@ -46,10 +50,29 @@ class ModelConfig:
             model_override_args=model_override_args,
         )
         self.hf_text_config = get_hf_text_config(self.hf_config)
+        derived_context_len = get_context_length(self.hf_text_config)
+        allow_long_context = os.environ.get(
+            "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", None
+        )
+
         if context_length is not None:
-            self.context_len = context_length
+            if context_length > derived_context_len:
+                if allow_long_context:
+                    logger.warning(
+                        f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
+                        f"This may lead to incorrect model outputs or CUDA errors."
+                    )
+                    self.context_len = context_length
+                else:
+                    raise ValueError(
+                        f"User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
+                        f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config. "
+                        f"To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
+                    )
+            else:
+                self.context_len = context_length
         else:
-            self.context_len = get_context_length(self.hf_text_config)
+            self.context_len = derived_context_len
 
         # Unify the config keys for hf_text_config
         self.head_dim = getattr(
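The new ModelConfig logic only changes behavior when a user-specified context length exceeds the value derived from the Hugging Face config. Below is a minimal standalone sketch of the same decision, written for illustration only (the helper name and prints are assumptions, not part of the package):

import os

def pick_context_len(user_len, derived_len):
    # Mirror of the gate above: a longer-than-derived value is only accepted
    # when SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN is set in the environment.
    if user_len is None:
        return derived_len
    if user_len <= derived_len:
        return user_len
    if os.environ.get("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"):
        print(f"warning: overriding derived context length {derived_len} with {user_len}")
        return user_len
    raise ValueError(
        f"context_length {user_len} exceeds derived {derived_len}; "
        "set SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1 to override"
    )

print(pick_context_len(4096, 8192))  # 4096
print(pick_context_len(None, 8192))  # 8192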
sglang/srt/constrained/fsm_cache.py CHANGED
@@ -73,9 +73,16 @@ class FSMCache(BaseToolCache):
     def init_value(self, key):
         key_type, key_string = key
         if key_type == "json":
-            …
+            try:
+                regex = build_regex_from_schema(
+                    key_string,
+                    whitespace_pattern=self.constrained_json_whitespace_pattern,
+                )
+            except NotImplementedError as e:
+                logger.warning(
+                    f"skip invalid json schema: json_schema={key_string}, {e=}"
+                )
+                return None, key_string
         elif key_type == "regex":
             regex = key_string
         else:
sglang/srt/hf_transformers_utils.py CHANGED
@@ -163,6 +163,8 @@ def get_tokenizer(
             "Using a slow tokenizer. This might cause a significant "
             "slowdown. Consider using a fast tokenizer instead."
         )
+
+    attach_additional_stop_token_ids(tokenizer)
     return tokenizer
 
 
@@ -181,4 +183,16 @@ def get_processor(
         tokenizer_revision=tokenizer_revision,
         **kwargs,
     )
+
+    attach_additional_stop_token_ids(processor.tokenizer)
     return processor
+
+
+def attach_additional_stop_token_ids(tokenizer):
+    # Special handling for stop token <|eom_id|> generated by llama 3 tool use.
+    if "<|eom_id|>" in tokenizer.get_added_vocab():
+        tokenizer.additional_stop_token_ids = set(
+            [tokenizer.get_added_vocab()["<|eom_id|>"]]
+        )
+    else:
+        tokenizer.additional_stop_token_ids = None
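The <|eom_id|> lookup relies on the standard tokenizer.get_added_vocab() API from transformers. A small hypothetical check outside sglang (the model id is only an example; any Llama 3.x tool-use tokenizer that registers <|eom_id|> as an added token behaves the same way):

from transformers import AutoTokenizer

# Example model id; assumed to define <|eom_id|> as an added special token.
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
added = tok.get_added_vocab()  # maps added special tokens to their ids
extra_stop_ids = set([added["<|eom_id|>"]]) if "<|eom_id|>" in added else None
print(extra_stop_ids)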
sglang/srt/layers/attention/flashinfer_backend.py CHANGED
@@ -337,7 +337,7 @@ class FlashInferIndicesUpdaterDecode:
     def update(
         self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens
     ):
-        # Keep the signature for type checking
+        # Keep the signature for type checking. It will be assigned during runtime.
        raise NotImplementedError()
 
     def update_single_wrapper(
@@ -432,8 +432,8 @@ class FlashInferIndicesUpdaterDecode:
         kv_start_idx,
     ):
         bs = len(req_pool_indices)
+        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
         kv_indptr = kv_indptr[: bs + 1]
-        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
         kv_indices = torch.empty(
             paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
         )
@@ -497,7 +497,7 @@ class FlashInferIndicesUpdaterPrefill:
         self.update = self.update_single_wrapper
 
     def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens):
-        # Keep the signature for type checking
+        # Keep the signature for type checking. It will be assigned during runtime.
        raise NotImplementedError()
 
     def update_single_wrapper(
@@ -589,8 +589,8 @@ class FlashInferIndicesUpdaterPrefill:
         use_ragged,
     ):
         bs = len(req_pool_indices)
+        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
         kv_indptr = kv_indptr[: bs + 1]
-        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
         kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
         create_flashinfer_kv_indices_triton[(bs,)](
             self.req_to_token,
@@ -602,8 +602,8 @@ class FlashInferIndicesUpdaterPrefill:
             self.max_context_len,
         )
 
+        qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
         qo_indptr = qo_indptr[: bs + 1]
-        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
 
         # extend part
         if use_ragged:
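The indptr change writes the running sum into the preallocated buffer before taking the bs + 1 slice, instead of slicing first and then filling the view. A tiny numeric illustration of what the slice ends up holding (shapes made up for the example):

import torch

paged_kernel_lens = torch.tensor([3, 5, 2], dtype=torch.int32)
bs = paged_kernel_lens.numel()
kv_indptr = torch.zeros(8, dtype=torch.int32)  # preallocated, longer than bs + 1

kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
print(kv_indptr.tolist())  # [0, 3, 8, 10]: prefix sums of per-request KV lengths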
sglang/srt/layers/logits_processor.py CHANGED
@@ -33,17 +33,17 @@ class LogitsProcessorOutput:
     # The logits of the next tokens. shape: [#seq, vocab_size]
     next_token_logits: torch.Tensor
     # The logprobs of the next tokens. shape: [#seq, vocab_size]
-    next_token_logprobs: torch.Tensor
+    next_token_logprobs: torch.Tensor = None
 
     # The normlaized logprobs of prompts. shape: [#seq]
-    normalized_prompt_logprobs: torch.Tensor
+    normalized_prompt_logprobs: torch.Tensor = None
     # The logprobs of input tokens. shape: [#token, vocab_size]
-    input_token_logprobs: torch.Tensor
+    input_token_logprobs: torch.Tensor = None
 
     # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-    input_top_logprobs: List
+    input_top_logprobs: List = None
     # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-    output_top_logprobs: List
+    output_top_logprobs: List = None
 
 
 @dataclasses.dataclass
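With every logprob-related field defaulting to None, callers such as the overlap worker can construct the output with only next_token_logits and attach logprobs later. A reduced sketch of that pattern (an illustrative stand-in class, not the packaged one):

import dataclasses
from typing import List

import torch

@dataclasses.dataclass
class _Output:  # stand-in for LogitsProcessorOutput
    next_token_logits: torch.Tensor
    next_token_logprobs: torch.Tensor = None
    output_top_logprobs: List = None

out = _Output(next_token_logits=torch.zeros(2, 8))  # no logprobs needed up front
print(out.next_token_logprobs is None)  # True; can be filled in later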
sglang/srt/layers/rotary_embedding.py CHANGED
@@ -22,64 +22,33 @@ class MRotaryEmbedding:
 
     @staticmethod
     def get_input_positions(
-        input_tokens: List[int],
+        input_tokens: torch.Tensor,
         image_grid_thw: Union[List[List[int]], torch.Tensor],
-        video_grid_thw: Union[List[List[int]], torch.Tensor],
-        image_token_id: int,
-        video_token_id: int,
         vision_start_token_id: int,
-        vision_end_token_id: int,
         spatial_merge_size: int,
         context_len: int = 0,
-        extend_prefix_len: int = 0,
     ) -> Tuple[List[List[int]], int]:
         """Get mrope input positions and delta value."""
 
         if isinstance(image_grid_thw, torch.Tensor):
             image_grid_thw = image_grid_thw.tolist()
-        if isinstance(video_grid_thw, torch.Tensor):
-            video_grid_thw = video_grid_thw.tolist()
 
-        input_tokens_tensor = torch.tensor(input_tokens)
         vision_start_indices = torch.argwhere(
-            input_tokens_tensor == vision_start_token_id
+            input_tokens == vision_start_token_id
         ).squeeze(1)
-        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
-        image_nums = (vision_tokens == image_token_id).sum()
-        video_nums = (vision_tokens == video_token_id).sum()
+        image_indices = vision_start_indices + 1
+        image_nums = image_indices.shape[0]
         llm_pos_ids_list: list = []
 
         st = 0
-        remain_images, remain_videos = image_nums, video_nums
-
-        image_index, video_index = 0, 0
-        for _ in range(image_nums + video_nums):
-            if image_token_id in input_tokens and remain_images > 0:
-                ed_image = input_tokens.index(image_token_id, st)
-            else:
-                ed_image = len(input_tokens) + 1
-            if video_token_id in input_tokens and remain_videos > 0:
-                ed_video = input_tokens.index(video_token_id, st)
-            else:
-                ed_video = len(input_tokens) + 1
-            if ed_image < ed_video:
-                t, h, w = (
-                    image_grid_thw[image_index][0],
-                    image_grid_thw[image_index][1],
-                    image_grid_thw[image_index][2],
-                )
-                image_index += 1
-                remain_images -= 1
-                ed = ed_image
-            else:
-                t, h, w = (
-                    video_grid_thw[video_index][0],
-                    video_grid_thw[video_index][1],
-                    video_grid_thw[video_index][2],
-                )
-                video_index += 1
-                remain_videos -= 1
-                ed = ed_video
+        input_tokens_len = input_tokens.shape[0]
+        for image_index in range(image_nums):
+            ed = image_indices[image_index].item()
+            t, h, w = (
+                image_grid_thw[image_index][0],
+                image_grid_thw[image_index][1],
+                image_grid_thw[image_index][2],
+            )
             llm_grid_t, llm_grid_h, llm_grid_w = (
                 t,
                 h // spatial_merge_size,
@@ -115,18 +84,16 @@ class MRotaryEmbedding:
             )
             st = ed + llm_grid_t * llm_grid_h * llm_grid_w
 
-        if st < len(input_tokens):
+        if st < input_tokens_len:
             st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-            text_len = len(input_tokens) - st
+            text_len = input_tokens_len - st
             llm_pos_ids_list.append(
                 torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
             )
 
         llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
         llm_positions = llm_positions[:, context_len:]
-        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
-        llm_positions += extend_prefix_len
-
+        mrope_position_delta = (llm_positions.max() + 1 - input_tokens_len).item()
         return llm_positions.tolist(), mrope_position_delta
 
     @staticmethod
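The rewritten helper drops the video branch and works directly on a token tensor: every vision-start marker is assumed to be immediately followed by an image token, so image positions are simply the marker positions shifted by one. A toy illustration with made-up token ids:

import torch

VISION_START = 151652  # placeholder id, for the example only
input_tokens = torch.tensor([1, 5, VISION_START, 777, 9, VISION_START, 777, 2])

vision_start_indices = torch.argwhere(input_tokens == VISION_START).squeeze(1)
image_indices = vision_start_indices + 1          # token right after each marker
print(image_indices.tolist(), image_indices.shape[0])  # [3, 6] 2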
sglang/srt/layers/sampler.py CHANGED
@@ -1,4 +1,5 @@
 import logging
+import os
 from typing import Union
 
 import torch
@@ -17,6 +18,11 @@ if is_flashinfer_available():
         top_p_renorm_prob,
     )
 
+
+# Crash on warning if we are running CI tests
+crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
+
+
 logger = logging.getLogger(__name__)
 
 
@@ -33,56 +39,62 @@ class Sampler(nn.Module):
         if isinstance(logits, LogitsProcessorOutput):
             logits = logits.next_token_logits
 
-        # Post process logits
         logits = logits.contiguous()
-        logits.div_(sampling_info.temperatures)
-        probs = torch.softmax(logits, dim=-1)
-        logits = None
-        del logits
-
-        if self.use_nan_detectioin and torch.any(torch.isnan(probs)):
-            logger.warning("Detected errors during sampling! NaN in the probability.")
-            probs = torch.where(
-                torch.isnan(probs), torch.full_like(probs, 1e-10), probs
-            )
-
-        if sampling_info.is_all_greedy:
-            # Use torch.argmax if all requests use greedy sampling
-            batch_next_token_ids = torch.argmax(probs, -1)
-        elif global_server_args_dict["sampling_backend"] == "flashinfer":
-            max_top_k_round, batch_size = 32, probs.shape[0]
-            uniform_samples = torch.rand(
-                (max_top_k_round, batch_size), device=probs.device
-            )
-            if sampling_info.need_min_p_sampling:
-                probs = top_k_renorm_prob(probs, sampling_info.top_ks)
-                probs = top_p_renorm_prob(probs, sampling_info.top_ps)
-                batch_next_token_ids, success = min_p_sampling_from_probs(
-                    probs, uniform_samples, sampling_info.min_ps
-                )
-            else:
-                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
-                    probs,
-                    uniform_samples,
-                    sampling_info.top_ks,
-                    sampling_info.top_ps,
-                    filter_apply_order="joint",
-                )
-
-            if not torch.all(success):
-                logger.warning("Detected errors during sampling!")
-                batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
-        elif global_server_args_dict["sampling_backend"] == "pytorch":
-            # Here we provide a slower fallback implementation.
-            batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
-                probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
-            )
-        else:
-            raise ValueError(
-                f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
-            )
-
-        return batch_next_token_ids
+
+        if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
+            logger.warning("Detected errors during sampling! NaN in the logits.")
+            logits = torch.where(
+                torch.isnan(logits), torch.full_like(logits, -1e5), logits
+            )
+            exit(1) if crash_on_warning else None
+
+        if sampling_info.is_all_greedy:
+            # Use torch.argmax if all requests use greedy sampling
+            batch_next_token_ids = torch.argmax(logits, -1)
+        else:
+            # Post process logits
+            logits.div_(sampling_info.temperatures)
+            probs = torch.softmax(logits, dim=-1)
+            logits = None
+            del logits
+
+            if global_server_args_dict["sampling_backend"] == "flashinfer":
+                max_top_k_round, batch_size = 32, probs.shape[0]
+                uniform_samples = torch.rand(
+                    (max_top_k_round, batch_size), device=probs.device
+                )
+                if sampling_info.need_min_p_sampling:
+                    probs = top_k_renorm_prob(probs, sampling_info.top_ks)
+                    probs = top_p_renorm_prob(probs, sampling_info.top_ps)
+                    batch_next_token_ids, success = min_p_sampling_from_probs(
+                        probs, uniform_samples, sampling_info.min_ps
+                    )
+                else:
+                    batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                        probs,
+                        uniform_samples,
+                        sampling_info.top_ks,
+                        sampling_info.top_ps,
+                        filter_apply_order="joint",
+                    )
+
+                if not torch.all(success):
+                    logger.warning("Detected errors during sampling!")
+                    batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
+            elif global_server_args_dict["sampling_backend"] == "pytorch":
+                # A slower fallback implementation with torch native operations.
+                batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
+                    probs,
+                    sampling_info.top_ks,
+                    sampling_info.top_ps,
+                    sampling_info.min_ps,
+                )
+            else:
+                raise ValueError(
+                    f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
+                )
+
+        return batch_next_token_ids.to(torch.int32)
 
 
 def top_k_top_p_min_p_sampling_from_probs_torch(
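For orientation, this is roughly what the torch-native fallback has to do per request: filter the probability row by top-k, top-p, and min-p, renormalize, and draw one sample. The sketch below is a simplified single-row illustration, not the packaged top_k_top_p_min_p_sampling_from_probs_torch:

import torch

def sample_one(probs, top_k, top_p, min_p):
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cum = torch.cumsum(sorted_probs, dim=-1)
    keep = torch.arange(probs.numel()) < top_k            # top-k cutoff
    keep &= (cum - sorted_probs) < top_p                  # top-p (nucleus) cutoff
    keep &= sorted_probs >= min_p * sorted_probs[0]       # min-p relative cutoff
    filtered = torch.where(keep, sorted_probs, torch.zeros_like(sorted_probs))
    choice = torch.multinomial(filtered / filtered.sum(), 1)
    return sorted_idx[choice].item()

torch.manual_seed(0)
print(sample_one(torch.tensor([0.5, 0.3, 0.15, 0.05]), top_k=3, top_p=0.9, min_p=0.05))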
sglang/srt/managers/detokenizer_manager.py CHANGED
@@ -27,6 +27,7 @@ from sglang.srt.managers.io_struct import (
     BatchEmbeddingOut,
     BatchStrOut,
     BatchTokenIDOut,
+    GetMemPoolSizeReqOutput,
     UpdateWeightReqOutput,
 )
 from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
@@ -111,6 +112,9 @@ class DetokenizerManager:
                 # If it is a weight update request, no detokenization is needed.
                 self.send_to_tokenizer.send_pyobj(recv_obj)
                 continue
+            elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
+                self.send_to_tokenizer.send_pyobj(recv_obj)
+                continue
             elif self.tokenizer is None:
                 # If the tokenizer is skipped, no detokenization is needed
                 self.send_to_tokenizer.send_pyobj(recv_obj)
sglang/srt/managers/io_struct.py CHANGED
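The io_struct.py hunk itself is not reproduced in this extract. Based on the imports added in scheduler.py, tokenizer_manager.py, and detokenizer_manager.py below, the new message types presumably look roughly like the following sketch (the field name is a guess; the scheduler replies with its max_total_num_tokens):

import dataclasses

@dataclasses.dataclass
class GetMemPoolSizeReq:
    pass

@dataclasses.dataclass
class GetMemPoolSizeReqOutput:
    size: int  # presumably the scheduler's max_total_num_tokens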
sglang/srt/managers/schedule_batch.py CHANGED
@@ -334,15 +334,20 @@ class Req:
 
         last_token_id = self.output_ids[-1]
 
-        matched_eos = last_token_id in self.sampling_params.stop_token_ids
+        matched_eos = False
 
+        # Check stop token ids
+        if self.sampling_params.stop_token_ids:
+            matched_eos = last_token_id in self.sampling_params.stop_token_ids
         if self.tokenizer is not None:
             matched_eos |= last_token_id == self.tokenizer.eos_token_id
-
+            if self.tokenizer.additional_stop_token_ids:
+                matched_eos |= last_token_id in self.tokenizer.additional_stop_token_ids
         if matched_eos and not self.sampling_params.ignore_eos:
             self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
             return
 
+        # Check stop strings
         if len(self.sampling_params.stop_strs) > 0:
             tail_str = self.tokenizer.decode(
                 self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
@@ -514,7 +519,12 @@ class ScheduleBatch:
         out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
 
         if out_cache_loc is None:
-            logger.error("Decode out of memory. Try to lower your batch size.")
+            phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
+            logger.error(
+                f"{phase_str} out of memory. Try to lower your batch size.\n"
+                f"Try to allocate {num_tokens} tokens.\n"
+                f"Avaliable tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
+            )
             if self.tree_cache is not None:
                 self.tree_cache.pretty_print()
                 exit(1)
sglang/srt/managers/scheduler.py CHANGED
@@ -38,6 +38,8 @@ from sglang.srt.managers.io_struct import (
     BatchEmbeddingOut,
     BatchTokenIDOut,
     FlushCacheReq,
+    GetMemPoolSizeReq,
+    GetMemPoolSizeReqOutput,
     ProfileReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -69,7 +71,6 @@ from sglang.srt.utils import (
     is_generation_model,
     is_multimodal_model,
     kill_parent_process,
-    pytorch_profile,
     set_random_seed,
     suppress_other_loggers,
 )
@@ -363,6 +364,10 @@ class Scheduler:
                 self.start_profile()
             else:
                 self.stop_profile()
+        elif isinstance(recv_req, GetMemPoolSizeReq):
+            self.send_to_detokenizer.send_pyobj(
+                GetMemPoolSizeReqOutput(self.max_total_num_tokens)
+            )
         else:
             raise ValueError(f"Invalid request: {recv_req}")
 
@@ -416,7 +421,7 @@ class Scheduler:
         )
 
         # Truncate prompts that are too long
-        if len(req.origin_input_ids) …
+        if len(req.origin_input_ids) > self.max_req_input_len:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
@@ -828,6 +833,7 @@ class Scheduler:
 
         if self.enable_overlap:
             logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
+            next_token_logprobs = logits_output.next_token_logprobs
         else:
             # Move next_token_ids and logprobs to cpu
             if batch.return_logprob:
sglang/srt/managers/tokenizer_manager.py CHANGED
@@ -46,6 +46,8 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
+    GetMemPoolSizeReq,
+    GetMemPoolSizeReqOutput,
     ProfileReq,
     RewardReqInput,
     TokenizedEmbeddingReqInput,
@@ -531,6 +533,15 @@ class TokenizerManager:
         req = ProfileReq.STOP_PROFILE
         self.send_to_scheduler.send_pyobj(req)
 
+    async def get_memory_pool_size(self):
+        if self.to_create_loop:
+            self.create_handle_loop()
+
+        req = GetMemPoolSizeReq()
+        self.send_to_scheduler.send_pyobj(req)
+        self.mem_pool_size = asyncio.Future()
+        return await self.mem_pool_size
+
     async def update_weights(
         self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
     ):
@@ -590,6 +601,9 @@ class TokenizerManager:
             if isinstance(recv_obj, UpdateWeightReqOutput):
                 self.model_update_result.set_result(recv_obj)
                 continue
+            elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
+                self.mem_pool_size.set_result(recv_obj)
+                continue
 
             assert isinstance(
                 recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
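Putting the three managers together, the memory-pool query is a simple round trip: TokenizerManager sends a GetMemPoolSizeReq to the scheduler, the scheduler answers through the detokenizer, and the awaiting future resolves with the GetMemPoolSizeReqOutput. A hypothetical caller, assuming an already-initialized TokenizerManager instance:

import asyncio

async def report_mem_pool_size(tokenizer_manager):
    # Resolves once the GetMemPoolSizeReqOutput travels back through the detokenizer.
    out = await tokenizer_manager.get_memory_pool_size()
    print("memory pool size:", out)

# asyncio.run(report_mem_pool_size(tokenizer_manager))  # requires a running server stack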
sglang/srt/managers/tp_worker_overlap_thread.py CHANGED
@@ -32,6 +32,15 @@ from sglang.srt.server_args import ServerArgs
 logger = logging.getLogger(__name__)
 
 
+@torch.compile(dynamic=True)
+def resolve_future_token_ids(input_ids, future_token_ids_map):
+    input_ids[:] = torch.where(
+        input_ids < 0,
+        future_token_ids_map[torch.clamp(-input_ids, min=0)],
+        input_ids,
+    )
+
+
 class TpModelWorkerClient:
     """A tensor parallel model worker."""
 
@@ -94,46 +103,69 @@ class TpModelWorkerClient:
         while True:
             self.has_inflight_batch = False
             model_worker_batch, future_token_ids_ct = self.input_queue.get()
+            if not model_worker_batch:
+                break
             self.has_inflight_batch = True
             self.launch_event = threading.Event()
 
             # Resolve future tokens in the input
             input_ids = model_worker_batch.input_ids
-            input_ids[:] = torch.where(
-                input_ids < 0,
-                self.future_token_ids_map[torch.clamp(-input_ids, min=0)],
-                input_ids,
-            )
+            resolve_future_token_ids(input_ids, self.future_token_ids_map)
 
             # Run forward
             logits_output, next_token_ids = self.worker.forward_batch_generation(
                 model_worker_batch
             )
-            self.launch_event.set()
 
             # Update the future token ids map
             bs = len(model_worker_batch.seq_lens)
-            …
+            self.future_token_ids_map[
+                future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
+            ] = next_token_ids
+
+            # Copy results to the CPU
+            if model_worker_batch.return_logprob:
+                logits_output.next_token_logprobs = logits_output.next_token_logprobs[
+                    torch.arange(len(next_token_ids), device=self.device),
+                    next_token_ids,
+                ].to("cpu", non_blocking=True)
+                if logits_output.input_token_logprobs is not None:
+                    logits_output.input_token_logprobs = (
+                        logits_output.input_token_logprobs.to("cpu", non_blocking=True)
+                    )
+                    logits_output.normalized_prompt_logprobs = (
+                        logits_output.normalized_prompt_logprobs.to(
+                            "cpu", non_blocking=True
+                        )
+                    )
             next_token_ids = next_token_ids.to("cpu", non_blocking=True)
             copy_event = torch.cuda.Event(blocking=True)
             copy_event.record()
-            self.copy_queue.put((copy_event, next_token_ids))
+
+            self.launch_event.set()
+            self.copy_queue.put((copy_event, logits_output, next_token_ids))
 
     def copy_thread_func(self):
         while True:
-            copy_event, next_token_ids = self.copy_queue.get()
+            copy_event, logits_output, next_token_ids = self.copy_queue.get()
+            if not copy_event:
+                break
             while not copy_event.query():
                 time.sleep(1e-5)
-            …
+
+            if logits_output.next_token_logprobs is not None:
+                logits_output.next_token_logprobs = (
+                    logits_output.next_token_logprobs.tolist()
+                )
+                if logits_output.input_token_logprobs is not None:
+                    logits_output.input_token_logprobs = (
+                        logits_output.input_token_logprobs.tolist()
+                    )
+                    logits_output.normalized_prompt_logprobs = (
+                        logits_output.normalized_prompt_logprobs.tolist()
+                    )
+
+            self.output_queue.put((logits_output, next_token_ids.tolist()))
 
     def resulve_batch_result(self, bid: int):
         logits_output, next_token_ids = self.output_queue.get()
@@ -149,8 +181,9 @@ class TpModelWorkerClient:
         # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
         future_next_token_ids = torch.arange(
-            -(self.future_token_ids_ct + bs),
-            -(self.future_token_ids_ct),
+            -(self.future_token_ids_ct + 1),
+            -(self.future_token_ids_ct + 1 + bs),
+            -1,
             dtype=torch.int32,
             device=self.device,
         )
@@ -170,3 +203,7 @@ class TpModelWorkerClient:
             recv_req.model_path, recv_req.load_format
         )
         return success, message
+
+    def __delete__(self):
+        self.input_queue.put((None, None))
+        self.copy_queue.put((None, None, None))
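A worked toy example of the placeholder scheme the overlap worker relies on: while the previous batch is still running, the next batch's inputs hold negative ids that index into future_token_ids_map; once the real ids land in the map, resolve_future_token_ids swaps them in. Sizes and ids below are invented for illustration:

import torch

ct, bs = 4, 3                                   # counter value and batch size
future_token_ids_map = torch.zeros(16, dtype=torch.int64)

# Placeholders handed to the next batch: -(ct+1), -(ct+2), ..., -(ct+bs)
future_next_token_ids = torch.arange(-(ct + 1), -(ct + 1 + bs), -1)
print(future_next_token_ids.tolist())           # [-5, -6, -7]

# The forward thread later writes the sampled ids into slots ct+1 .. ct+bs ...
next_token_ids = torch.tensor([11, 22, 33])
future_token_ids_map[ct + 1 : ct + bs + 1] = next_token_ids

# ... and the next batch's input_ids get resolved in place of the placeholders.
input_ids = torch.tensor([101, -5, 102, -6, -7])
input_ids = torch.where(
    input_ids < 0, future_token_ids_map[torch.clamp(-input_ids, min=0)], input_ids
)
print(input_ids.tolist())                       # [101, 11, 102, 22, 33]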
|