sglang 0.4.1.post6__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +21 -23
- sglang/api.py +2 -7
- sglang/bench_offline_throughput.py +24 -16
- sglang/bench_one_batch.py +51 -3
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +37 -28
- sglang/lang/backend/runtime_endpoint.py +183 -4
- sglang/lang/chat_template.py +15 -4
- sglang/launch_server.py +1 -1
- sglang/srt/_custom_ops.py +80 -42
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +21 -0
- sglang/srt/constrained/xgrammar_backend.py +8 -4
- sglang/srt/conversation.py +14 -1
- sglang/srt/distributed/__init__.py +3 -3
- sglang/srt/distributed/communication_op.py +2 -1
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
- sglang/srt/distributed/device_communicators/pynccl.py +80 -1
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
- sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
- sglang/srt/distributed/parallel_state.py +1 -1
- sglang/srt/distributed/utils.py +2 -1
- sglang/srt/entrypoints/engine.py +449 -0
- sglang/srt/entrypoints/http_server.py +579 -0
- sglang/srt/layers/activation.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +10 -9
- sglang/srt/layers/attention/triton_backend.py +4 -6
- sglang/srt/layers/attention/vision.py +204 -0
- sglang/srt/layers/dp_attention.py +69 -0
- sglang/srt/layers/linear.py +41 -5
- sglang/srt/layers/logits_processor.py +48 -63
- sglang/srt/layers/moe/ep_moe/layer.py +4 -4
- sglang/srt/layers/moe/fused_moe_native.py +69 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
- sglang/srt/layers/moe/fused_moe_triton/layer.py +29 -5
- sglang/srt/layers/parameter.py +2 -1
- sglang/srt/layers/quantization/__init__.py +20 -23
- sglang/srt/layers/quantization/fp8.py +6 -3
- sglang/srt/layers/quantization/modelopt_quant.py +1 -2
- sglang/srt/layers/quantization/w8a8_int8.py +1 -1
- sglang/srt/layers/radix_attention.py +2 -2
- sglang/srt/layers/rotary_embedding.py +1179 -31
- sglang/srt/layers/sampler.py +39 -1
- sglang/srt/layers/vocab_parallel_embedding.py +2 -2
- sglang/srt/lora/lora.py +1 -9
- sglang/srt/managers/configure_logging.py +3 -0
- sglang/srt/managers/data_parallel_controller.py +79 -72
- sglang/srt/managers/detokenizer_manager.py +23 -6
- sglang/srt/managers/image_processor.py +158 -2
- sglang/srt/managers/io_struct.py +25 -2
- sglang/srt/managers/schedule_batch.py +49 -22
- sglang/srt/managers/schedule_policy.py +26 -12
- sglang/srt/managers/scheduler.py +277 -178
- sglang/srt/managers/session_controller.py +1 -0
- sglang/srt/managers/tokenizer_manager.py +206 -121
- sglang/srt/managers/tp_worker.py +6 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
- sglang/srt/managers/utils.py +44 -0
- sglang/srt/mem_cache/memory_pool.py +10 -32
- sglang/srt/metrics/collector.py +15 -6
- sglang/srt/model_executor/cuda_graph_runner.py +4 -6
- sglang/srt/model_executor/model_runner.py +37 -15
- sglang/srt/model_loader/loader.py +8 -6
- sglang/srt/model_loader/weight_utils.py +55 -2
- sglang/srt/models/baichuan.py +6 -6
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +3 -3
- sglang/srt/models/dbrx.py +4 -4
- sglang/srt/models/deepseek.py +3 -3
- sglang/srt/models/deepseek_v2.py +8 -8
- sglang/srt/models/exaone.py +2 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +6 -24
- sglang/srt/models/gpt2.py +3 -5
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/granite.py +2 -2
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/internlm2.py +2 -2
- sglang/srt/models/llama.py +7 -5
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/minicpm3.py +6 -6
- sglang/srt/models/minicpmv.py +1238 -0
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mixtral_quant.py +3 -3
- sglang/srt/models/mllama.py +2 -2
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/olmo2.py +4 -4
- sglang/srt/models/olmoe.py +7 -13
- sglang/srt/models/phi3_small.py +2 -2
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +41 -4
- sglang/srt/models/qwen2_moe.py +3 -3
- sglang/srt/models/qwen2_vl.py +22 -122
- sglang/srt/models/stablelm.py +2 -2
- sglang/srt/models/torch_native_llama.py +3 -3
- sglang/srt/models/xverse.py +6 -6
- sglang/srt/models/xverse_moe.py +6 -6
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/custom_logit_processor.py +38 -0
- sglang/srt/sampling/sampling_batch_info.py +139 -4
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +4 -1090
- sglang/srt/server_args.py +57 -14
- sglang/srt/utils.py +103 -65
- sglang/test/runners.py +8 -13
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +12 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +16 -5
- {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +119 -115
- sglang/launch_server_llavavid.py +0 -25
- sglang/srt/constrained/__init__.py +0 -16
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/srt/layers/sampler.py
CHANGED
@@ -1,11 +1,12 @@
 import logging
-from typing import List
+from typing import Dict, List
 
 import torch
 from torch import nn
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
 
@@ -35,6 +36,10 @@ class Sampler(nn.Module):
     ):
         logits = logits_output.next_token_logits
 
+        # Apply the custom logit processors if registered in the sampling info.
+        if sampling_info.has_custom_logit_processor:
+            self._apply_custom_logit_processor(logits, sampling_info)
+
         if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
             logger.warning("Detected errors during sampling! NaN in the logits.")
             logits = torch.where(
@@ -121,6 +126,39 @@ class Sampler(nn.Module):
 
         return batch_next_token_ids
 
+    def _apply_custom_logit_processor(
+        self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
+    ):
+        """Apply custom logit processors to the logits.
+        This function will modify the logits in-place."""
+
+        assert logits.shape[0] == len(sampling_batch_info), (
+            f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
+            f"sampling_batch_info ({len(sampling_batch_info)})"
+        )
+
+        for _, (
+            processor,
+            batch_mask,
+        ) in sampling_batch_info.custom_logit_processor.items():
+            # Get the batch indices that need to be processed
+            batch_indices = batch_mask.nonzero(as_tuple=True)[0]
+
+            assert batch_mask.shape[0] == len(sampling_batch_info), (
+                f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of "
+                f"sampling_batch_info ({len(sampling_batch_info)})"
+            )
+
+            # Apply the processor to the logits
+            logits[batch_mask] = processor(
+                logits[batch_mask],
+                [sampling_batch_info.custom_params[i] for i in batch_indices],
+            )
+
+            logger.debug(
+                f"Custom logit processor {processor.__class__.__name__} is applied."
+            )
+
 
 def top_k_top_p_min_p_sampling_from_probs_torch(
     probs: torch.Tensor,
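The sampler change wires per-request custom logit processors into sampling: the logits rows selected by each processor's batch mask are handed to the processor together with the matching custom_params entries, and the processor must return logits of the same shape. Below is a minimal sketch of a callable compatible with that call site; the class name and the "disallowed_token_ids" parameter are illustrative and not from the package, and the actual CustomLogitProcessor base class (sglang/srt/sampling/custom_logit_processor.py) is not shown in this diff.

# Illustrative sketch only -- not the package's CustomLogitProcessor API.
# It matches the call site above: (selected logits, per-request params) -> logits.
import torch


class DisallowedTokenProcessor:
    def __call__(self, logits: torch.Tensor, custom_params: list) -> torch.Tensor:
        # logits has shape [num_selected_requests, vocab_size]; one params entry per row.
        for row, params in enumerate(custom_params):
            for token_id in (params or {}).get("disallowed_token_ids", []):
                logits[row, token_id] = float("-inf")
        return logits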
sglang/srt/layers/vocab_parallel_embedding.py
CHANGED
@@ -6,13 +6,13 @@ from typing import List, Optional, Sequence, Tuple
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter, UninitializedParameter
-
+
+from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-
 from sglang.srt.layers.parameter import BasevLLMParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
sglang/srt/lora/lora.py
CHANGED
@@ -19,18 +19,11 @@
 # https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py
 
 
-import json
-import os
 import re
-from typing import Any, Dict, List, Optional, Tuple
 
-import safetensors.torch
 import torch
 from torch import nn
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
+from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -38,7 +31,6 @@ from sglang.srt.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.loader import DefaultModelLoader
 
 
sglang/srt/managers/configure_logging.py
CHANGED
@@ -27,6 +27,7 @@ import requests
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--url", type=str, default="http://localhost:30000")
+    parser.add_argument("--log-requests", action="store_true")
     parser.add_argument(
         "--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
     )
@@ -36,6 +37,8 @@ if __name__ == "__main__":
     response = requests.post(
         args.url + "/configure_logging",
         json={
+            "log_requests": args.log_requests,
+            "log_requests_level": 1,  # Log full requests
            "dump_requests_folder": args.dump_requests_folder,
            "dump_requests_threshold": args.dump_requests_threshold,
        },
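With this change the configure_logging script forwards a --log-requests flag and a fixed log_requests_level of 1 to the server's /configure_logging endpoint. The equivalent request, sketched with the requests library; the URL and folder are the script's defaults and the threshold value is a placeholder, not taken from the diff.

# Equivalent to what the updated script sends; values here are illustrative.
import requests

requests.post(
    "http://localhost:30000/configure_logging",
    json={
        "log_requests": True,
        "log_requests_level": 1,  # log full requests
        "dump_requests_folder": "/tmp/sglang_request_dump",
        "dump_requests_threshold": 1000,  # placeholder; the script takes --dump-requests-threshold
    },
)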
sglang/srt/managers/data_parallel_controller.py
CHANGED
@@ -23,6 +23,7 @@ import psutil
 import setproctitle
 import zmq
 
+from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.managers.io_struct import (
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -55,6 +56,7 @@ class DataParallelController:
 
     def __init__(self, server_args, port_args) -> None:
         # Parse args
+        self.max_total_num_tokens = None
         self.server_args = server_args
         self.port_args = port_args
         self.load_balance_method = LoadBalanceMethod.from_str(
@@ -63,9 +65,10 @@
 
         # Init inter-process communication
         self.context = zmq.Context(1 + server_args.dp_size)
-
-        self.
-
+        if server_args.node_rank == 0:
+            self.recv_from_tokenizer = get_zmq_socket(
+                self.context, zmq.PULL, port_args.scheduler_input_ipc_name, False
+            )
 
         # Dispatch method
         self.round_robin_counter = 0
@@ -75,33 +78,50 @@ class DataParallelController:
         }
         self.dispatching = dispatch_lookup[self.load_balance_method]
 
-        #
-
+        # Launch data parallel workers
+        self.scheduler_procs = []
         self.workers = [None] * server_args.dp_size
 
+        if not server_args.enable_dp_attention:
+            dp_port_args = self.launch_dp_schedulers(server_args, port_args)
+        else:
+            dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
+
+        # Only node rank 0 runs the real data parallel controller that dispatches the requests.
+        if server_args.node_rank == 0:
+            for dp_rank in range(server_args.dp_size):
+                self.workers[dp_rank] = get_zmq_socket(
+                    self.context,
+                    zmq.PUSH,
+                    dp_port_args[dp_rank].scheduler_input_ipc_name,
+                    True,
+                )
+
+        self.max_req_input_len = None
+
+    def launch_dp_schedulers(self, server_args, port_args):
+        base_gpu_id = 0
+
         threads = []
         sockets = []
+        dp_port_args = []
         for dp_rank in range(server_args.dp_size):
             tmp_port_args = PortArgs.init_new(server_args)
             tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
             tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
+            dp_port_args.append(tmp_port_args)
 
-
-
-
-                tmp_port_args.nccl_port = port_args.nccl_port
-            else:
-                # This port is checked free in PortArgs.init_new.
-                # We hold it first so that the next dp worker gets a different port
-                sockets.append(bind_port(tmp_port_args.nccl_port))
+            # This port is checked free in PortArgs.init_new.
+            # We hold it first so that the next dp worker gets a different port
+            sockets.append(bind_port(tmp_port_args.nccl_port))
 
             # Create a thread for each worker
             thread = threading.Thread(
-                target=self.
+                target=self.launch_tensor_parallel_group,
                 args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
             )
             threads.append(thread)
-            base_gpu_id +=
+            base_gpu_id += server_args.tp_size
 
         # Free all sockets before starting the threads to launch TP workers
         for sock in sockets:
@@ -113,26 +133,14 @@ class DataParallelController:
         for thread in threads:
             thread.join()
 
-
-        self,
-        server_args: ServerArgs,
-        port_args: PortArgs,
-        base_gpu_id: int,
-        dp_rank: int,
-    ):
-        logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
+        return dp_port_args
 
-
-
-
-
-
-
-            server_args,
-            port_args,
-            base_gpu_id,
-            dp_rank,
-        )
+    def launch_dp_attention_schedulers(self, server_args, port_args):
+        self.launch_tensor_parallel_group(server_args, port_args, 0, None)
+        dp_port_args = []
+        for dp_rank in range(server_args.dp_size):
+            dp_port_args.append(PortArgs.init_new(server_args, dp_rank))
+        return dp_port_args
 
     def launch_tensor_parallel_group(
         self,
@@ -141,8 +149,10 @@ class DataParallelController:
         base_gpu_id: int,
         dp_rank: int,
     ):
+        if not server_args.enable_dp_attention:
+            logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
+
         # Launch tensor parallel scheduler processes
-        scheduler_procs = []
         scheduler_pipe_readers = []
         tp_size_per_node = server_args.tp_size // server_args.nnodes
         tp_rank_range = range(
@@ -150,52 +160,39 @@ class DataParallelController:
             tp_size_per_node * (server_args.node_rank + 1),
         )
         for tp_rank in tp_rank_range:
+            rank_port_args = port_args
+
+            if server_args.enable_dp_attention:
+                # dp attention has different sharding logic
+                _, _, dp_rank = compute_dp_attention_world_info(
+                    server_args.enable_dp_attention,
+                    tp_rank,
+                    server_args.tp_size,
+                    server_args.dp_size,
+                )
+                # compute zmq ports for this dp rank
+                rank_port_args = PortArgs.init_new(server_args, dp_rank)
+                # Data parallelism resues the tensor parallelism group,
+                # so all dp ranks should use the same nccl port.
+                rank_port_args.nccl_port = port_args.nccl_port
+
             reader, writer = mp.Pipe(duplex=False)
             gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node
             proc = mp.Process(
                 target=run_scheduler_process,
-                args=(server_args,
+                args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
             )
             proc.start()
-            scheduler_procs.append(proc)
+            self.scheduler_procs.append(proc)
             scheduler_pipe_readers.append(reader)
 
-
-            self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
-        )
-
-        # Wait for model to finish loading and get max token nums
+        # Wait for model to finish loading
         scheduler_info = []
         for i in range(len(scheduler_pipe_readers)):
             scheduler_info.append(scheduler_pipe_readers[i].recv())
 
         self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
-
-        return send_to
-
-    def launch_tensor_parallel_process(
-        self,
-        server_args: ServerArgs,
-        port_args: PortArgs,
-        base_gpu_id: int,
-        dp_rank: int,
-    ):
-        reader, writer = mp.Pipe(duplex=False)
-        gpu_id = base_gpu_id
-        tp_rank = dp_rank
-        proc = mp.Process(
-            target=run_scheduler_process,
-            args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
-        )
-        proc.start()
-        send_to = get_zmq_socket(
-            self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
-        )
-
-        scheduler_info = reader.recv()
-        self.max_total_num_tokens = scheduler_info["max_total_num_tokens"]
-
-        return send_to
+        self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
 
     def round_robin_scheduler(self, req):
         self.workers[self.round_robin_counter].send_pyobj(req)
@@ -221,8 +218,8 @@ class DataParallelController:
                 ):
                     self.dispatching(recv_req)
                 else:
-                    # Send other control messages to
-                    for worker in self.workers:
+                    # Send other control messages to first worker of tp group
+                    for worker in self.workers[:: self.server_args.tp_size]:
                         worker.send_pyobj(recv_req)
 
 
@@ -238,9 +235,19 @@ def run_data_parallel_controller_process(
     try:
         controller = DataParallelController(server_args, port_args)
         pipe_writer.send(
-            {
+            {
+                "status": "ready",
+                "max_total_num_tokens": controller.max_total_num_tokens,
+                "max_req_input_len": controller.max_req_input_len,
+            }
         )
-
+        if server_args.node_rank == 0:
+            controller.event_loop()
+        for proc in controller.scheduler_procs:
+            proc.join()
+            logger.error(
+                f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
+            )
     except Exception:
         traceback = get_exception_traceback()
         logger.error(f"DataParallelController hit an exception: {traceback}")
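In the reworked controller, only node rank 0 pulls requests from the tokenizer and holds one PUSH socket per data-parallel rank; generate/embedding requests are dispatched round-robin, while other control messages go to the first worker of each tensor-parallel group (self.workers[:: tp_size]). A simplified sketch of that dispatch pattern follows; the class and handle names are assumptions, not the package's code.

# Simplified sketch of the dispatch pattern visible in the diff (assumed names).
class RoundRobinDispatcher:
    def __init__(self, workers, tp_size):
        self.workers = workers  # one socket-like handle per data-parallel rank
        self.tp_size = tp_size
        self.counter = 0

    def dispatch(self, req, is_generation_request):
        if is_generation_request:
            # Load-balance generate/embedding requests across DP ranks.
            self.workers[self.counter].send_pyobj(req)
            self.counter = (self.counter + 1) % len(self.workers)
        else:
            # Broadcast control messages to the first worker of each TP group.
            for worker in self.workers[:: self.tp_size]:
                worker.send_pyobj(req)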
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -15,6 +15,7 @@
 
 import dataclasses
 import logging
+import os
 import signal
 from collections import OrderedDict
 from typing import Dict, List, Union
@@ -35,6 +36,12 @@ from sglang.utils import find_printable_text, get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
+# Maximum number of request states that detokenizer can hold. When exceeded,
+# oldest request states will be evicted. Default: 65536 (1<<16).
+# For more details, see: https://github.com/sgl-project/sglang/issues/2812
+# Use power of 2 values for better memory allocation.
+DETOKENIZER_MAX_STATES = int(os.environ.get("SGLANG_DETOKENIZER_MAX_STATES", 1 << 16))
+
 
 @dataclasses.dataclass
 class DecodeStatus:
@@ -58,10 +65,10 @@ class DetokenizerManager:
         # Init inter-process communication
         context = zmq.Context(2)
         self.recv_from_scheduler = get_zmq_socket(
-            context, zmq.PULL, port_args.detokenizer_ipc_name
+            context, zmq.PULL, port_args.detokenizer_ipc_name, True
         )
         self.send_to_tokenizer = get_zmq_socket(
-            context, zmq.PUSH, port_args.tokenizer_ipc_name
+            context, zmq.PUSH, port_args.tokenizer_ipc_name, False
         )
 
         if server_args.skip_tokenizer_init:
@@ -71,9 +78,10 @@ class DetokenizerManager:
                 server_args.tokenizer_path,
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
+                revision=server_args.revision,
             )
 
-        self.decode_status = LimitedCapacityDict()
+        self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES)
 
     def trim_matched_stop(
         self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
@@ -155,7 +163,17 @@ class DetokenizerManager:
         # Incremental decoding
         output_strs = []
         for i in range(bs):
-
+            try:
+                s = self.decode_status[recv_obj.rids[i]]
+            except KeyError:
+                raise RuntimeError(
+                    f"Decode status not found for request {recv_obj.rids[i]}. "
+                    "It may be due to the request being evicted from the decode status due to memory pressure. "
+                    "Please increase the maximum number of requests by setting "
+                    "the SGLANG_DETOKENIZER_MAX_STATES environment variable to a bigger value than the default value. "
+                    f"The current value is {DETOKENIZER_MAX_STATES}. "
+                    "For more details, see: https://github.com/sgl-project/sglang/issues/2812"
+                )
             new_text = read_texts[i][len(surr_texts[i]) :]
             if recv_obj.finished_reasons[i] is None:
                 # Streaming chunk: update the decode status
@@ -191,13 +209,12 @@ class DetokenizerManager:
                 input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
                 output_top_logprobs_val=recv_obj.output_top_logprobs_val,
                 output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
-                normalized_prompt_logprob=recv_obj.normalized_prompt_logprob,
             )
         )
 
 
 class LimitedCapacityDict(OrderedDict):
-    def __init__(self, capacity
+    def __init__(self, capacity: int, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.capacity = capacity
 
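The detokenizer now bounds its per-request decode state with a LimitedCapacityDict sized by the SGLANG_DETOKENIZER_MAX_STATES environment variable (default 1 << 16 = 65536) and raises a descriptive error when a state has already been evicted. Only the constructor is visible in this hunk, so the following is a hedged sketch of how such a capacity-bounded OrderedDict typically evicts its oldest entry, not the package's implementation.

# Sketch of the eviction idea; the package's __setitem__ is not shown in this hunk.
import os
from collections import OrderedDict

DETOKENIZER_MAX_STATES = int(os.environ.get("SGLANG_DETOKENIZER_MAX_STATES", 1 << 16))


class BoundedStateDict(OrderedDict):
    def __init__(self, capacity: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.capacity = capacity

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        if len(self) > self.capacity:
            self.popitem(last=False)  # evict the oldest request state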
sglang/srt/managers/image_processor.py
CHANGED
@@ -9,6 +9,8 @@ from typing import List, Optional, Union
 
 import numpy as np
 import transformers
+from decord import VideoReader, cpu
+from PIL import Image
 
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.mm_utils import expand2square, process_anyres_image
@@ -36,6 +38,7 @@ class BaseImageProcessor(ABC):
     def __init__(self, hf_config, server_args, _processor):
         self.hf_config = hf_config
         self._processor = _processor
+        self.server_args = server_args
 
         self.executor = concurrent.futures.ProcessPoolExecutor(
             initializer=init_global_processor,
@@ -126,7 +129,12 @@ class LlavaImageProcessor(BaseImageProcessor):
         )
 
     async def process_images_async(
-        self,
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        *args,
+        **kwargs,
     ):
         if not image_data:
             return None
@@ -229,6 +237,147 @@ class MllamaImageProcessor(BaseImageProcessor):
         return image_inputs
 
 
+class MiniCPMVImageProcessor(BaseImageProcessor):
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+
+    @staticmethod
+    def _process_images_task(images, input_text):
+        result = global_processor.__call__(
+            text=input_text, images=images, return_tensors="pt"
+        )
+        return {
+            "input_ids": result["input_ids"],
+            "pixel_values": result["pixel_values"],
+            "tgt_sizes": result["tgt_sizes"],
+        }
+
+    async def _process_images(self, images, input_text):
+        if self.executor is not None:
+            loop = asyncio.get_event_loop()
+            image_inputs = await loop.run_in_executor(
+                self.executor,
+                MiniCPMVImageProcessor._process_images_task,
+                images,
+                input_text,
+            )
+        else:
+            image_inputs = self._processor(
+                images=images, text=input_text, return_tensors="pt"
+            )
+
+        return image_inputs
+
+    async def process_images_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        max_req_input_len,
+    ):
+        if not image_data:
+            return None
+
+        if not isinstance(image_data, list):
+            image_data = [image_data]
+
+        image_hashes, image_sizes = [], []
+        raw_images = []
+        IMAGE_TOKEN = "(<image>./</image>)"
+
+        # roughly calculate the max number of frames
+        # TODO: the process should be applied to all the visual inputs
+        def calculate_max_num_frames() -> int:
+            # Model-specific
+            NUM_TOKEN_PER_FRAME = 330
+
+            ret = (max_req_input_len - len(input_text)) // NUM_TOKEN_PER_FRAME
+            return min(ret, 100)
+
+        # if cuda OOM set a smaller number
+        MAX_NUM_FRAMES = calculate_max_num_frames()
+        print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")
+
+        def encode_video(video_path):
+            if not os.path.exists(video_path):
+                logger.error(f"Video {video_path} does not exist")
+                return []
+
+            if MAX_NUM_FRAMES == 0:
+                return []
+
+            def uniform_sample(l, n):
+                gap = len(l) / n
+                idxs = [int(i * gap + gap / 2) for i in range(n)]
+                return [l[i] for i in idxs]
+
+            vr = VideoReader(video_path, ctx=cpu(0))
+            sample_fps = round(vr.get_avg_fps() / 1)  # FPS
+            frame_idx = [i for i in range(0, len(vr), sample_fps)]
+            if len(frame_idx) > MAX_NUM_FRAMES:
+                frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
+            frames = vr.get_batch(frame_idx).asnumpy()
+            frames = [Image.fromarray(v.astype("uint8")) for v in frames]
+            return frames
+
+        if isinstance(input_text, list):
+            assert len(input_text) and isinstance(input_text[0], int)
+            input_text = self._processor.tokenizer.decode(input_text)
+
+        # MiniCPMV requires each frame of video as a single image token
+        text_parts = input_text.split(IMAGE_TOKEN)
+        new_text_parts = []
+
+        for image_index, image in enumerate(image_data):
+            try:
+                if isinstance(image, str) and image.startswith("video:"):
+                    path = image[len("video:") :]
+                    frames = encode_video(path)
+                else:
+                    raw_image, size = load_image(image)
+                    frames = [raw_image]
+                if len(frames) == 0:
+                    continue
+            except FileNotFoundError as e:
+                print(e)
+                return None
+
+            image_sizes += frames[0].size * len(frames)
+            image_hashes += [hash(image)] * len(frames)
+            raw_images += frames
+            new_text_parts.append(text_parts[image_index])
+            new_text_parts.append(IMAGE_TOKEN * len(frames))
+
+        new_text_parts.append(text_parts[-1])
+        input_text = "".join(new_text_parts)
+        if len(raw_images) == 0:
+            return None
+        res = await self._process_images(images=raw_images, input_text=input_text)
+        pixel_values = res["pixel_values"]
+        tgt_sizes = res["tgt_sizes"]
+        input_ids = res["input_ids"]
+
+        # Collect special token ids
+        tokenizer = self._processor.tokenizer
+        im_start_id = [tokenizer.im_start_id]
+        im_end_id = [tokenizer.im_end_id]
+        if tokenizer.slice_start_id:
+            slice_start_id = [tokenizer.slice_start_id]
+            slice_end_id = [tokenizer.slice_end_id]
+
+        return {
+            "input_ids": input_ids.flatten().tolist(),
+            "pixel_values": pixel_values,
+            "tgt_sizes": tgt_sizes,
+            "image_hashes": image_hashes,
+            "modalities": request_obj.modalities or ["image"],
+            "im_start_id": im_start_id,
+            "im_end_id": im_end_id,
+            "slice_start_id": slice_start_id,
+            "slice_end_id": slice_end_id,
+        }
+
+
 class Qwen2VLImageProcessor(BaseImageProcessor):
     def __init__(self, hf_config, server_args, _image_processor):
         self.hf_config = hf_config
@@ -289,7 +438,12 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
             return self._process_single_image_task(image_data)
 
     async def process_images_async(
-        self,
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        *args,
+        **kwargs,
    ):
        if not image_data:
            return None
@@ -350,6 +504,8 @@ def get_image_processor(
         return MllamaImageProcessor(hf_config, server_args, processor)
     elif "Qwen2VLForConditionalGeneration" in hf_config.architectures:
         return Qwen2VLImageProcessor(hf_config, server_args, processor.image_processor)
+    elif "MiniCPMV" in hf_config.architectures:
+        return MiniCPMVImageProcessor(hf_config, server_args, processor)
     else:
         return LlavaImageProcessor(hf_config, server_args, processor.image_processor)
 
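The new MiniCPMVImageProcessor budgets video frames from the space left in the prompt: roughly (max_req_input_len - len(input_text)) // 330 frames, capped at 100, and it uniformly subsamples frame indices when a video has more. A standalone restatement of that arithmetic, with the constants taken from the diff above:

# Standalone restatement of the frame-budget arithmetic from the diff.
NUM_TOKEN_PER_FRAME = 330  # model-specific constant used above


def max_num_frames(max_req_input_len: int, prompt_len: int) -> int:
    return min((max_req_input_len - prompt_len) // NUM_TOKEN_PER_FRAME, 100)


def uniform_sample(items, n):
    # Same uniform subsampling scheme as encode_video() above.
    gap = len(items) / n
    return [items[int(i * gap + gap / 2)] for i in range(n)]


# Example: a 32768-token request limit with a 2048-token prompt allows
# (32768 - 2048) // 330 = 93 frames.
assert max_num_frames(32768, 2048) == 93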