sglang 0.3.4.post1__py3-none-any.whl → 0.3.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_latency.py +3 -3
- sglang/bench_server_latency.py +2 -3
- sglang/bench_serving.py +92 -0
- sglang/global_config.py +9 -3
- sglang/lang/chat_template.py +50 -25
- sglang/lang/interpreter.py +9 -1
- sglang/lang/ir.py +11 -2
- sglang/launch_server.py +1 -1
- sglang/srt/configs/model_config.py +76 -15
- sglang/srt/constrained/__init__.py +18 -0
- sglang/srt/constrained/bnf_cache.py +61 -0
- sglang/srt/constrained/fsm_cache.py +10 -3
- sglang/srt/constrained/grammar.py +190 -0
- sglang/srt/hf_transformers_utils.py +20 -5
- sglang/srt/layers/attention/flashinfer_backend.py +5 -5
- sglang/srt/layers/attention/triton_ops/decode_attention.py +110 -30
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +1 -1
- sglang/srt/layers/fused_moe/fused_moe.py +4 -3
- sglang/srt/layers/fused_moe/layer.py +28 -0
- sglang/srt/layers/logits_processor.py +5 -5
- sglang/srt/layers/quantization/base_config.py +16 -1
- sglang/srt/layers/rotary_embedding.py +15 -48
- sglang/srt/layers/sampler.py +51 -39
- sglang/srt/layers/vocab_parallel_embedding.py +486 -0
- sglang/srt/managers/data_parallel_controller.py +8 -7
- sglang/srt/managers/detokenizer_manager.py +11 -9
- sglang/srt/managers/image_processor.py +4 -3
- sglang/srt/managers/io_struct.py +80 -78
- sglang/srt/managers/schedule_batch.py +46 -52
- sglang/srt/managers/schedule_policy.py +24 -13
- sglang/srt/managers/scheduler.py +145 -82
- sglang/srt/managers/tokenizer_manager.py +236 -334
- sglang/srt/managers/tp_worker.py +5 -5
- sglang/srt/managers/tp_worker_overlap_thread.py +58 -21
- sglang/srt/mem_cache/flush_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +10 -3
- sglang/srt/model_executor/cuda_graph_runner.py +34 -23
- sglang/srt/model_executor/forward_batch_info.py +6 -9
- sglang/srt/model_executor/model_runner.py +10 -19
- sglang/srt/models/baichuan.py +4 -4
- sglang/srt/models/chatglm.py +4 -4
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +5 -5
- sglang/srt/models/deepseek.py +4 -4
- sglang/srt/models/deepseek_v2.py +4 -4
- sglang/srt/models/exaone.py +4 -4
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -1
- sglang/srt/models/gpt2.py +287 -0
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +4 -4
- sglang/srt/models/internlm2.py +4 -4
- sglang/srt/models/llama.py +15 -7
- sglang/srt/models/llama_embedding.py +2 -10
- sglang/srt/models/llama_reward.py +5 -0
- sglang/srt/models/minicpm.py +4 -4
- sglang/srt/models/minicpm3.py +4 -4
- sglang/srt/models/mixtral.py +7 -5
- sglang/srt/models/mixtral_quant.py +4 -4
- sglang/srt/models/mllama.py +5 -5
- sglang/srt/models/olmo.py +4 -4
- sglang/srt/models/olmoe.py +4 -4
- sglang/srt/models/qwen.py +4 -4
- sglang/srt/models/qwen2.py +4 -4
- sglang/srt/models/qwen2_moe.py +4 -4
- sglang/srt/models/qwen2_vl.py +4 -8
- sglang/srt/models/stablelm.py +4 -4
- sglang/srt/models/torch_native_llama.py +4 -4
- sglang/srt/models/xverse.py +4 -4
- sglang/srt/models/xverse_moe.py +4 -4
- sglang/srt/openai_api/adapter.py +52 -66
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
- sglang/srt/sampling/sampling_batch_info.py +7 -13
- sglang/srt/sampling/sampling_params.py +5 -7
- sglang/srt/server.py +41 -33
- sglang/srt/server_args.py +34 -5
- sglang/srt/utils.py +40 -56
- sglang/test/run_eval.py +2 -0
- sglang/test/runners.py +2 -1
- sglang/test/srt/sampling/penaltylib/utils.py +1 -0
- sglang/test/test_utils.py +151 -6
- sglang/utils.py +62 -1
- sglang/version.py +1 -1
- sglang-0.3.5.dist-info/METADATA +344 -0
- sglang-0.3.5.dist-info/RECORD +152 -0
- {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/WHEEL +1 -1
- sglang-0.3.4.post1.dist-info/METADATA +0 -900
- sglang-0.3.4.post1.dist-info/RECORD +0 -148
- {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/LICENSE +0 -0
- {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/rotary_embedding.py
CHANGED
@@ -22,64 +22,33 @@ class MRotaryEmbedding:
 
     @staticmethod
     def get_input_positions(
-        input_tokens: List[int],
+        input_tokens: torch.Tensor,
         image_grid_thw: Union[List[List[int]], torch.Tensor],
-        video_grid_thw: Union[List[List[int]], torch.Tensor],
-        image_token_id: int,
-        video_token_id: int,
         vision_start_token_id: int,
-        vision_end_token_id: int,
         spatial_merge_size: int,
         context_len: int = 0,
-        extend_prefix_len: int = 0,
     ) -> Tuple[List[List[int]], int]:
         """Get mrope input positions and delta value."""
 
         if isinstance(image_grid_thw, torch.Tensor):
             image_grid_thw = image_grid_thw.tolist()
-        if isinstance(video_grid_thw, torch.Tensor):
-            video_grid_thw = video_grid_thw.tolist()
 
-        input_tokens_tensor = torch.tensor(input_tokens)
         vision_start_indices = torch.argwhere(
-            input_tokens_tensor == vision_start_token_id
+            input_tokens == vision_start_token_id
         ).squeeze(1)
-        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
-        image_nums = (vision_tokens == image_token_id).sum()
-        video_nums = (vision_tokens == video_token_id).sum()
+        image_indices = vision_start_indices + 1
+        image_nums = image_indices.shape[0]
         llm_pos_ids_list: list = []
 
         st = 0
-        remain_images, remain_videos = image_nums, video_nums
-
-        image_index, video_index = 0, 0
-        for _ in range(image_nums + video_nums):
-            if image_token_id in input_tokens and remain_images > 0:
-                ed_image = input_tokens.index(image_token_id, st)
-            else:
-                ed_image = len(input_tokens) + 1
-            if video_token_id in input_tokens and remain_videos > 0:
-                ed_video = input_tokens.index(video_token_id, st)
-            else:
-                ed_video = len(input_tokens) + 1
-            if ed_image < ed_video:
-                t, h, w = (
-                    image_grid_thw[image_index][0],
-                    image_grid_thw[image_index][1],
-                    image_grid_thw[image_index][2],
-                )
-                image_index += 1
-                remain_images -= 1
-                ed = ed_image
-            else:
-                t, h, w = (
-                    video_grid_thw[video_index][0],
-                    video_grid_thw[video_index][1],
-                    video_grid_thw[video_index][2],
-                )
-                video_index += 1
-                remain_videos -= 1
-                ed = ed_video
+        input_tokens_len = input_tokens.shape[0]
+        for image_index in range(image_nums):
+            ed = image_indices[image_index].item()
+            t, h, w = (
+                image_grid_thw[image_index][0],
+                image_grid_thw[image_index][1],
+                image_grid_thw[image_index][2],
+            )
             llm_grid_t, llm_grid_h, llm_grid_w = (
                 t,
                 h // spatial_merge_size,
@@ -115,18 +84,16 @@ class MRotaryEmbedding:
             )
             st = ed + llm_grid_t * llm_grid_h * llm_grid_w
 
-        if st < len(input_tokens):
+        if st < input_tokens_len:
             st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-            text_len = len(input_tokens) - st
+            text_len = input_tokens_len - st
             llm_pos_ids_list.append(
                 torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
             )
 
         llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
         llm_positions = llm_positions[:, context_len:]
-        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
-        llm_positions += extend_prefix_len
-
+        mrope_position_delta = (llm_positions.max() + 1 - input_tokens_len).item()
         return llm_positions.tolist(), mrope_position_delta
 
     @staticmethod
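Note on the change above: get_input_positions now takes the prompt as a torch.Tensor, drops the video branch entirely, and treats every vision_start marker as introducing an image block. The sketch below is our own illustration of the image-only position logic after this change, not code from the package; the t/h/w index construction paraphrases the unchanged middle of the function, and the token ids in the toy call are hypothetical.

import torch

def mrope_positions_sketch(
    input_tokens: torch.Tensor,  # 1-D prompt token ids
    image_grid_thw: list,        # one [t, h, w] patch grid per image
    vision_start_token_id: int,
    spatial_merge_size: int,
    context_len: int = 0,
):
    # Every vision_start marker is followed by an image block (no video path).
    vision_start_indices = torch.argwhere(
        input_tokens == vision_start_token_id
    ).squeeze(1)
    image_indices = vision_start_indices + 1
    image_nums = image_indices.shape[0]
    llm_pos_ids_list = []
    st = 0
    input_tokens_len = input_tokens.shape[0]
    for image_index in range(image_nums):
        ed = image_indices[image_index].item()
        t, h, w = image_grid_thw[image_index]
        llm_grid_t = t
        llm_grid_h = h // spatial_merge_size
        llm_grid_w = w // spatial_merge_size
        # Text tokens before the image share one position across all 3 axes.
        text_len = ed - st
        st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
        llm_pos_ids_list.append(
            torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
        )
        # Image tokens get separate temporal / height / width position streams.
        t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
            -1, llm_grid_h * llm_grid_w).flatten()
        h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
            llm_grid_t, -1, llm_grid_w).flatten()
        w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
            llm_grid_t, llm_grid_h, -1).flatten()
        llm_pos_ids_list.append(
            torch.stack([t_index, h_index, w_index]) + text_len + st_idx
        )
        st = ed + llm_grid_t * llm_grid_h * llm_grid_w
    # Trailing text after the last image.
    if st < input_tokens_len:
        st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
        text_len = input_tokens_len - st
        llm_pos_ids_list.append(
            torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
        )
    llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
    llm_positions = llm_positions[:, context_len:]
    mrope_position_delta = (llm_positions.max() + 1 - input_tokens_len).item()
    return llm_positions.tolist(), mrope_position_delta

# Toy call: token 9 marks vision start; a 1x4x4 grid merged 2x2 -> 4 image tokens.
pos, delta = mrope_positions_sketch(
    torch.tensor([5, 6, 9, 7, 7, 7, 7, 8, 8]), [[1, 4, 4]], 9, 2
)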
sglang/srt/layers/sampler.py
CHANGED
@@ -1,4 +1,5 @@
 import logging
+import os
 from typing import Union
 
 import torch
@@ -17,6 +18,11 @@ if is_flashinfer_available():
         top_p_renorm_prob,
     )
 
+
+# Crash on warning if we are running CI tests
+crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
+
+
 logger = logging.getLogger(__name__)
 
 
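Note: the hunk above adds a module-level flag so that sampling warnings become hard failures when the CI environment sets SGLANG_IS_IN_CI. A hedged sketch of the pattern follows; SGLANG_IS_IN_CI is the variable the diff actually reads, while warn_or_crash is a hypothetical helper for illustration, not part of sglang.

import os
import sys

# Opt into fail-fast behavior only when CI sets this variable.
crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"

def warn_or_crash(message: str) -> None:
    # Log locally, but abort in CI so regressions surface instead of being hidden.
    print(f"WARNING: {message}", file=sys.stderr)
    if crash_on_warning:
        sys.exit(1)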
@@ -33,56 +39,62 @@ class Sampler(nn.Module):
         if isinstance(logits, LogitsProcessorOutput):
             logits = logits.next_token_logits
 
-        # Post process logits
         logits = logits.contiguous()
-        logits.div_(sampling_info.temperatures)
-        probs = torch.softmax(logits, dim=-1)
-        logits = None
-        del logits
 
-        if self.use_nan_detectioin and torch.any(torch.isnan(probs)):
-            logger.warning("Detected errors during sampling! NaN in the probability.")
-            probs = torch.where(
-                torch.isnan(probs), torch.full_like(probs, 1e-10), probs
+        if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
+            logger.warning("Detected errors during sampling! NaN in the logits.")
+            logits = torch.where(
+                torch.isnan(logits), torch.full_like(logits, -1e5), logits
             )
+            exit(1) if crash_on_warning else None
 
         if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
-            batch_next_token_ids = torch.argmax(probs, -1)
-        elif global_server_args_dict["sampling_backend"] == "flashinfer":
-            max_top_k_round, batch_size = 32, probs.shape[0]
-            uniform_samples = torch.rand(
-                (max_top_k_round, batch_size), device=probs.device
-            )
-            if sampling_info.need_min_p_sampling:
-                probs = top_k_renorm_prob(probs, sampling_info.top_ks)
-                probs = top_p_renorm_prob(probs, sampling_info.top_ps)
-                batch_next_token_ids, success = min_p_sampling_from_probs(
-                    probs, uniform_samples, sampling_info.min_ps
-                )
-            else:
-                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
-                    probs,
-                    uniform_samples,
-                    sampling_info.top_ks,
-                    sampling_info.top_ps,
-                )
-
-            if not torch.all(success):
-                logger.warning("Detected errors during sampling!")
-                batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
-        elif global_server_args_dict["sampling_backend"] == "pytorch":
-            # Here we provide a slower fallback implementation.
-            batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
-                probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
-            )
-        else:
-            raise ValueError(
-                f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
-            )
-
-        return batch_next_token_ids
+            batch_next_token_ids = torch.argmax(logits, -1)
+        else:
+            # Post process logits
+            logits.div_(sampling_info.temperatures)
+            probs = torch.softmax(logits, dim=-1)
+            logits = None
+            del logits
+
+            if global_server_args_dict["sampling_backend"] == "flashinfer":
+                max_top_k_round, batch_size = 32, probs.shape[0]
+                uniform_samples = torch.rand(
+                    (max_top_k_round, batch_size), device=probs.device
+                )
+                if sampling_info.need_min_p_sampling:
+                    probs = top_k_renorm_prob(probs, sampling_info.top_ks)
+                    probs = top_p_renorm_prob(probs, sampling_info.top_ps)
+                    batch_next_token_ids, success = min_p_sampling_from_probs(
+                        probs, uniform_samples, sampling_info.min_ps
+                    )
+                else:
+                    batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                        probs,
+                        uniform_samples,
+                        sampling_info.top_ks,
+                        sampling_info.top_ps,
+                        filter_apply_order="joint",
+                    )
+
+                if not torch.all(success):
+                    logger.warning("Detected errors during sampling!")
+                    batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
+            elif global_server_args_dict["sampling_backend"] == "pytorch":
+                # A slower fallback implementation with torch native operations.
+                batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
+                    probs,
+                    sampling_info.top_ks,
+                    sampling_info.top_ps,
+                    sampling_info.min_ps,
+                )
+            else:
+                raise ValueError(
+                    f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
+                )
+
+        return batch_next_token_ids.to(torch.int32)
 
 
 def top_k_top_p_min_p_sampling_from_probs_torch(
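Note on the restructured forward: greedy requests now argmax the raw logits directly, temperature scaling and softmax happen only on the sampling path, and the result is cast to int32. The hunk ends at the signature of top_k_top_p_min_p_sampling_from_probs_torch, the PyTorch fallback that now receives probs plus per-request top_ks/top_ps/min_ps tensors. Below is a minimal, self-contained sketch of joint top-k/top-p/min-p filtering in plain PyTorch; it illustrates the technique under the assumption of [batch, vocab] probs and per-row parameter tensors, and is not the function body shipped in the wheel.

import torch

def sample_probs_sketch(probs: torch.Tensor, top_ks: torch.Tensor,
                        top_ps: torch.Tensor, min_ps: torch.Tensor) -> torch.Tensor:
    # probs: [batch, vocab] probabilities; top_ks (int), top_ps, min_ps: [batch].
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # min-p: zero out tokens whose probability falls below min_p * max_prob.
    probs_sort[probs_sort < (probs_sort[:, 0:1] * min_ps.unsqueeze(-1))] = 0.0
    # top-p: zero out tokens once the cumulative mass before them exceeds top_p.
    probs_sort[(probs_sum - probs_sort) > top_ps.unsqueeze(-1)] = 0.0
    # top-k: keep only the k highest-probability tokens in each row.
    ranks = torch.arange(probs.shape[-1], device=probs.device).expand_as(probs_sort)
    probs_sort[ranks >= top_ks.unsqueeze(-1)] = 0.0
    # Renormalize the surviving mass and sample one token per row.
    probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
    sampled_rank = torch.multinomial(probs_sort, num_samples=1)
    return probs_idx.gather(-1, sampled_rank).squeeze(-1)

# Toy call: batch of 2 rows over a 4-token vocabulary.
probs = torch.tensor([[0.5, 0.3, 0.1, 0.1], [0.7, 0.1, 0.1, 0.1]])
ids = sample_probs_sketch(probs, torch.tensor([2, 1]),
                          torch.tensor([0.9, 0.9]), torch.tensor([0.1, 0.1]))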