sglang 0.4.1.post6__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +21 -23
- sglang/api.py +2 -7
- sglang/bench_offline_throughput.py +41 -27
- sglang/bench_one_batch.py +60 -4
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +83 -71
- sglang/lang/backend/runtime_endpoint.py +183 -4
- sglang/lang/chat_template.py +46 -4
- sglang/launch_server.py +1 -1
- sglang/srt/_custom_ops.py +80 -42
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +21 -0
- sglang/srt/constrained/xgrammar_backend.py +8 -4
- sglang/srt/conversation.py +14 -1
- sglang/srt/distributed/__init__.py +3 -3
- sglang/srt/distributed/communication_op.py +2 -1
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +112 -42
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
- sglang/srt/distributed/device_communicators/pynccl.py +80 -1
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
- sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
- sglang/srt/distributed/parallel_state.py +1 -1
- sglang/srt/distributed/utils.py +2 -1
- sglang/srt/entrypoints/engine.py +452 -0
- sglang/srt/entrypoints/http_server.py +603 -0
- sglang/srt/function_call_parser.py +494 -0
- sglang/srt/layers/activation.py +8 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -9
- sglang/srt/layers/attention/triton_backend.py +4 -6
- sglang/srt/layers/attention/vision.py +204 -0
- sglang/srt/layers/dp_attention.py +71 -0
- sglang/srt/layers/layernorm.py +5 -5
- sglang/srt/layers/linear.py +65 -14
- sglang/srt/layers/logits_processor.py +49 -64
- sglang/srt/layers/moe/ep_moe/layer.py +24 -16
- sglang/srt/layers/moe/fused_moe_native.py +84 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -5
- sglang/srt/layers/parameter.py +18 -8
- sglang/srt/layers/quantization/__init__.py +20 -23
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/fp8.py +10 -4
- sglang/srt/layers/quantization/modelopt_quant.py +1 -2
- sglang/srt/layers/quantization/w8a8_int8.py +1 -1
- sglang/srt/layers/radix_attention.py +2 -2
- sglang/srt/layers/rotary_embedding.py +1184 -31
- sglang/srt/layers/sampler.py +64 -6
- sglang/srt/layers/torchao_utils.py +12 -6
- sglang/srt/layers/vocab_parallel_embedding.py +2 -2
- sglang/srt/lora/lora.py +1 -9
- sglang/srt/managers/configure_logging.py +3 -0
- sglang/srt/managers/data_parallel_controller.py +79 -72
- sglang/srt/managers/detokenizer_manager.py +24 -6
- sglang/srt/managers/image_processor.py +158 -2
- sglang/srt/managers/io_struct.py +57 -3
- sglang/srt/managers/schedule_batch.py +78 -45
- sglang/srt/managers/schedule_policy.py +26 -12
- sglang/srt/managers/scheduler.py +326 -201
- sglang/srt/managers/session_controller.py +1 -0
- sglang/srt/managers/tokenizer_manager.py +210 -121
- sglang/srt/managers/tp_worker.py +6 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
- sglang/srt/managers/utils.py +44 -0
- sglang/srt/mem_cache/memory_pool.py +10 -32
- sglang/srt/metrics/collector.py +15 -6
- sglang/srt/model_executor/cuda_graph_runner.py +26 -30
- sglang/srt/model_executor/forward_batch_info.py +5 -7
- sglang/srt/model_executor/model_runner.py +44 -19
- sglang/srt/model_loader/loader.py +83 -6
- sglang/srt/model_loader/weight_utils.py +145 -6
- sglang/srt/models/baichuan.py +6 -6
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +17 -5
- sglang/srt/models/dbrx.py +13 -5
- sglang/srt/models/deepseek.py +3 -3
- sglang/srt/models/deepseek_v2.py +11 -11
- sglang/srt/models/exaone.py +2 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +15 -25
- sglang/srt/models/gpt2.py +3 -5
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/granite.py +2 -2
- sglang/srt/models/grok.py +4 -3
- sglang/srt/models/internlm2.py +2 -2
- sglang/srt/models/llama.py +7 -5
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/minicpm3.py +9 -9
- sglang/srt/models/minicpmv.py +1238 -0
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mixtral_quant.py +3 -3
- sglang/srt/models/mllama.py +2 -2
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/olmo2.py +4 -4
- sglang/srt/models/olmoe.py +7 -13
- sglang/srt/models/phi3_small.py +2 -2
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +41 -4
- sglang/srt/models/qwen2_moe.py +3 -3
- sglang/srt/models/qwen2_vl.py +22 -122
- sglang/srt/models/stablelm.py +2 -2
- sglang/srt/models/torch_native_llama.py +20 -7
- sglang/srt/models/xverse.py +6 -6
- sglang/srt/models/xverse_moe.py +6 -6
- sglang/srt/openai_api/adapter.py +139 -37
- sglang/srt/openai_api/protocol.py +7 -4
- sglang/srt/sampling/custom_logit_processor.py +38 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
- sglang/srt/sampling/sampling_batch_info.py +143 -18
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +4 -1090
- sglang/srt/server_args.py +77 -15
- sglang/srt/speculative/eagle_utils.py +37 -15
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/utils.py +164 -129
- sglang/test/runners.py +8 -13
- sglang/test/test_programs.py +2 -1
- sglang/test/test_utils.py +83 -22
- sglang/utils.py +12 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/METADATA +21 -10
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/RECORD +138 -123
- sglang/launch_server_llavavid.py +0 -25
- sglang/srt/constrained/__init__.py +0 -16
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/sampler.py
CHANGED
@@ -2,15 +2,18 @@ import logging
|
|
2
2
|
from typing import List
|
3
3
|
|
4
4
|
import torch
|
5
|
+
import torch.distributed as dist
|
5
6
|
from torch import nn
|
6
7
|
|
8
|
+
from sglang.srt.distributed import get_tensor_model_parallel_group
|
9
|
+
from sglang.srt.layers.dp_attention import get_attention_tp_group
|
7
10
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
8
11
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
9
12
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
10
|
-
from sglang.srt.utils import crash_on_warnings,
|
13
|
+
from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda_available
|
11
14
|
|
12
|
-
if
|
13
|
-
from
|
15
|
+
if is_cuda_available():
|
16
|
+
from sgl_kernel import (
|
14
17
|
min_p_sampling_from_probs,
|
15
18
|
top_k_renorm_prob,
|
16
19
|
top_k_top_p_sampling_from_probs,
|
@@ -20,11 +23,17 @@ if is_flashinfer_available():
|
|
20
23
|
|
21
24
|
logger = logging.getLogger(__name__)
|
22
25
|
|
26
|
+
SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")
|
27
|
+
|
23
28
|
|
24
29
|
class Sampler(nn.Module):
|
25
30
|
def __init__(self):
|
26
31
|
super().__init__()
|
27
32
|
self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"]
|
33
|
+
self.tp_sync_group = get_tensor_model_parallel_group().device_group
|
34
|
+
|
35
|
+
if global_server_args_dict["enable_dp_attention"]:
|
36
|
+
self.tp_sync_group = get_attention_tp_group().device_group
|
28
37
|
|
29
38
|
def forward(
|
30
39
|
self,
|
@@ -35,6 +44,10 @@ class Sampler(nn.Module):
|
|
35
44
|
):
|
36
45
|
logits = logits_output.next_token_logits
|
37
46
|
|
47
|
+
# Apply the custom logit processors if registered in the sampling info.
|
48
|
+
if sampling_info.has_custom_logit_processor:
|
49
|
+
self._apply_custom_logit_processor(logits, sampling_info)
|
50
|
+
|
38
51
|
if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
|
39
52
|
logger.warning("Detected errors during sampling! NaN in the logits.")
|
40
53
|
logits = torch.where(
|
@@ -104,8 +117,6 @@ class Sampler(nn.Module):
|
|
104
117
|
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
|
105
118
|
)
|
106
119
|
|
107
|
-
batch_next_token_ids = batch_next_token_ids.to(torch.int32)
|
108
|
-
|
109
120
|
# Attach logprobs to logits_output (in-place modification)
|
110
121
|
if return_logprob:
|
111
122
|
if any(x > 0 for x in top_logprobs_nums):
|
@@ -119,7 +130,54 @@ class Sampler(nn.Module):
|
|
119
130
|
batch_next_token_ids,
|
120
131
|
]
|
121
132
|
|
122
|
-
|
133
|
+
if SYNC_TOKEN_IDS_ACROSS_TP or sampling_info.grammars:
|
134
|
+
# For performance reasons, SGLang does not sync the final token IDs across TP ranks by default.
|
135
|
+
# This saves one all-reduce, but the correctness of this approach depends on the determinism of several operators:
|
136
|
+
# the last all-reduce, the last lm_head matmul, and all sampling kernels.
|
137
|
+
# These kernels are deterministic in most cases, but there are some rare instances where they are not deterministic.
|
138
|
+
# In such cases, enable this env variable to prevent hanging due to TP ranks becoming desynchronized.
|
139
|
+
# When using xgrammar, this becomes more likely so we also do the sync when grammar is used.
|
140
|
+
|
141
|
+
torch.distributed.all_reduce(
|
142
|
+
batch_next_token_ids,
|
143
|
+
op=dist.ReduceOp.MIN,
|
144
|
+
group=self.tp_sync_group,
|
145
|
+
)
|
146
|
+
|
147
|
+
return batch_next_token_ids.to(torch.int32)
|
148
|
+
|
149
|
+
def _apply_custom_logit_processor(
|
150
|
+
self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
|
151
|
+
):
|
152
|
+
"""Apply custom logit processors to the logits.
|
153
|
+
This function will modify the logits in-place."""
|
154
|
+
|
155
|
+
assert logits.shape[0] == len(sampling_batch_info), (
|
156
|
+
f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
|
157
|
+
f"sampling_batch_info ({len(sampling_batch_info)})"
|
158
|
+
)
|
159
|
+
|
160
|
+
for _, (
|
161
|
+
processor,
|
162
|
+
batch_mask,
|
163
|
+
) in sampling_batch_info.custom_logit_processor.items():
|
164
|
+
# Get the batch indices that need to be processed
|
165
|
+
batch_indices = batch_mask.nonzero(as_tuple=True)[0]
|
166
|
+
|
167
|
+
assert batch_mask.shape[0] == len(sampling_batch_info), (
|
168
|
+
f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of "
|
169
|
+
f"sampling_batch_info ({len(sampling_batch_info)})"
|
170
|
+
)
|
171
|
+
|
172
|
+
# Apply the processor to the logits
|
173
|
+
logits[batch_mask] = processor(
|
174
|
+
logits[batch_mask],
|
175
|
+
[sampling_batch_info.custom_params[i] for i in batch_indices],
|
176
|
+
)
|
177
|
+
|
178
|
+
logger.debug(
|
179
|
+
f"Custom logit processor {processor.__class__.__name__} is applied."
|
180
|
+
)
|
123
181
|
|
124
182
|
|
125
183
|
def top_k_top_p_min_p_sampling_from_probs_torch(
|
@@ -5,6 +5,7 @@ Common utilities for torchao.
|
|
5
5
|
import logging
|
6
6
|
import os
|
7
7
|
import pwd
|
8
|
+
from typing import Callable, Optional
|
8
9
|
|
9
10
|
import torch
|
10
11
|
|
@@ -27,8 +28,18 @@ def save_gemlite_cache(print_error: bool = False) -> bool:
|
|
27
28
|
return True
|
28
29
|
|
29
30
|
|
31
|
+
def proj_filter(
|
32
|
+
module: torch.nn.Module,
|
33
|
+
fqn: str,
|
34
|
+
):
|
35
|
+
"""Filter function for quantizing projection layers."""
|
36
|
+
return "proj" in fqn
|
37
|
+
|
38
|
+
|
30
39
|
def apply_torchao_config_to_model(
|
31
|
-
model: torch.nn.Module,
|
40
|
+
model: torch.nn.Module,
|
41
|
+
torchao_config: str,
|
42
|
+
filter_fn: Optional[Callable] = proj_filter,
|
32
43
|
):
|
33
44
|
"""Quantize a modelwith torchao quantization specified by torchao_config
|
34
45
|
|
@@ -49,11 +60,6 @@ def apply_torchao_config_to_model(
|
|
49
60
|
)
|
50
61
|
from torchao.quantization.observer import PerRow, PerTensor
|
51
62
|
|
52
|
-
if filter_fn is None:
|
53
|
-
|
54
|
-
def filter_fn(module, fqn):
|
55
|
-
return "proj" in fqn
|
56
|
-
|
57
63
|
if torchao_config == "" or torchao_config is None:
|
58
64
|
return model
|
59
65
|
elif "int8wo" in torchao_config:
|
@@ -6,13 +6,13 @@ from typing import List, Optional, Sequence, Tuple
|
|
6
6
|
import torch
|
7
7
|
import torch.nn.functional as F
|
8
8
|
from torch.nn.parameter import Parameter, UninitializedParameter
|
9
|
-
|
9
|
+
|
10
|
+
from sglang.srt.distributed import (
|
10
11
|
divide,
|
11
12
|
get_tensor_model_parallel_rank,
|
12
13
|
get_tensor_model_parallel_world_size,
|
13
14
|
tensor_model_parallel_all_reduce,
|
14
15
|
)
|
15
|
-
|
16
16
|
from sglang.srt.layers.parameter import BasevLLMParameter
|
17
17
|
from sglang.srt.layers.quantization.base_config import (
|
18
18
|
QuantizationConfig,
|
sglang/srt/lora/lora.py
CHANGED
@@ -19,18 +19,11 @@
|
|
19
19
|
# https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py
|
20
20
|
|
21
21
|
|
22
|
-
import json
|
23
|
-
import os
|
24
22
|
import re
|
25
|
-
from typing import Any, Dict, List, Optional, Tuple
|
26
23
|
|
27
|
-
import safetensors.torch
|
28
24
|
import torch
|
29
25
|
from torch import nn
|
30
|
-
from vllm.model_executor.layers.vocab_parallel_embedding import
|
31
|
-
ParallelLMHead,
|
32
|
-
VocabParallelEmbedding,
|
33
|
-
)
|
26
|
+
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
34
27
|
|
35
28
|
from sglang.srt.layers.linear import (
|
36
29
|
ColumnParallelLinear,
|
@@ -38,7 +31,6 @@ from sglang.srt.layers.linear import (
|
|
38
31
|
QKVParallelLinear,
|
39
32
|
RowParallelLinear,
|
40
33
|
)
|
41
|
-
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
42
34
|
from sglang.srt.model_loader.loader import DefaultModelLoader
|
43
35
|
|
44
36
|
|
@@ -27,6 +27,7 @@ import requests
|
|
27
27
|
if __name__ == "__main__":
|
28
28
|
parser = argparse.ArgumentParser()
|
29
29
|
parser.add_argument("--url", type=str, default="http://localhost:30000")
|
30
|
+
parser.add_argument("--log-requests", action="store_true")
|
30
31
|
parser.add_argument(
|
31
32
|
"--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
|
32
33
|
)
|
@@ -36,6 +37,8 @@ if __name__ == "__main__":
|
|
36
37
|
response = requests.post(
|
37
38
|
args.url + "/configure_logging",
|
38
39
|
json={
|
40
|
+
"log_requests": args.log_requests,
|
41
|
+
"log_requests_level": 1, # Log full requests
|
39
42
|
"dump_requests_folder": args.dump_requests_folder,
|
40
43
|
"dump_requests_threshold": args.dump_requests_threshold,
|
41
44
|
},
|
@@ -23,6 +23,7 @@ import psutil
|
|
23
23
|
import setproctitle
|
24
24
|
import zmq
|
25
25
|
|
26
|
+
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
|
26
27
|
from sglang.srt.managers.io_struct import (
|
27
28
|
TokenizedEmbeddingReqInput,
|
28
29
|
TokenizedGenerateReqInput,
|
@@ -55,6 +56,7 @@ class DataParallelController:
|
|
55
56
|
|
56
57
|
def __init__(self, server_args, port_args) -> None:
|
57
58
|
# Parse args
|
59
|
+
self.max_total_num_tokens = None
|
58
60
|
self.server_args = server_args
|
59
61
|
self.port_args = port_args
|
60
62
|
self.load_balance_method = LoadBalanceMethod.from_str(
|
@@ -63,9 +65,10 @@ class DataParallelController:
|
|
63
65
|
|
64
66
|
# Init inter-process communication
|
65
67
|
self.context = zmq.Context(1 + server_args.dp_size)
|
66
|
-
|
67
|
-
self.
|
68
|
-
|
68
|
+
if server_args.node_rank == 0:
|
69
|
+
self.recv_from_tokenizer = get_zmq_socket(
|
70
|
+
self.context, zmq.PULL, port_args.scheduler_input_ipc_name, False
|
71
|
+
)
|
69
72
|
|
70
73
|
# Dispatch method
|
71
74
|
self.round_robin_counter = 0
|
@@ -75,33 +78,50 @@ class DataParallelController:
|
|
75
78
|
}
|
76
79
|
self.dispatching = dispatch_lookup[self.load_balance_method]
|
77
80
|
|
78
|
-
#
|
79
|
-
|
81
|
+
# Launch data parallel workers
|
82
|
+
self.scheduler_procs = []
|
80
83
|
self.workers = [None] * server_args.dp_size
|
81
84
|
|
85
|
+
if not server_args.enable_dp_attention:
|
86
|
+
dp_port_args = self.launch_dp_schedulers(server_args, port_args)
|
87
|
+
else:
|
88
|
+
dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
|
89
|
+
|
90
|
+
# Only node rank 0 runs the real data parallel controller that dispatches the requests.
|
91
|
+
if server_args.node_rank == 0:
|
92
|
+
for dp_rank in range(server_args.dp_size):
|
93
|
+
self.workers[dp_rank] = get_zmq_socket(
|
94
|
+
self.context,
|
95
|
+
zmq.PUSH,
|
96
|
+
dp_port_args[dp_rank].scheduler_input_ipc_name,
|
97
|
+
True,
|
98
|
+
)
|
99
|
+
|
100
|
+
self.max_req_input_len = None
|
101
|
+
|
102
|
+
def launch_dp_schedulers(self, server_args, port_args):
|
103
|
+
base_gpu_id = 0
|
104
|
+
|
82
105
|
threads = []
|
83
106
|
sockets = []
|
107
|
+
dp_port_args = []
|
84
108
|
for dp_rank in range(server_args.dp_size):
|
85
109
|
tmp_port_args = PortArgs.init_new(server_args)
|
86
110
|
tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
|
87
111
|
tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
|
112
|
+
dp_port_args.append(tmp_port_args)
|
88
113
|
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
tmp_port_args.nccl_port = port_args.nccl_port
|
93
|
-
else:
|
94
|
-
# This port is checked free in PortArgs.init_new.
|
95
|
-
# We hold it first so that the next dp worker gets a different port
|
96
|
-
sockets.append(bind_port(tmp_port_args.nccl_port))
|
114
|
+
# This port is checked free in PortArgs.init_new.
|
115
|
+
# We hold it first so that the next dp worker gets a different port
|
116
|
+
sockets.append(bind_port(tmp_port_args.nccl_port))
|
97
117
|
|
98
118
|
# Create a thread for each worker
|
99
119
|
thread = threading.Thread(
|
100
|
-
target=self.
|
120
|
+
target=self.launch_tensor_parallel_group,
|
101
121
|
args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
|
102
122
|
)
|
103
123
|
threads.append(thread)
|
104
|
-
base_gpu_id +=
|
124
|
+
base_gpu_id += server_args.tp_size
|
105
125
|
|
106
126
|
# Free all sockets before starting the threads to launch TP workers
|
107
127
|
for sock in sockets:
|
@@ -113,26 +133,14 @@ class DataParallelController:
|
|
113
133
|
for thread in threads:
|
114
134
|
thread.join()
|
115
135
|
|
116
|
-
|
117
|
-
self,
|
118
|
-
server_args: ServerArgs,
|
119
|
-
port_args: PortArgs,
|
120
|
-
base_gpu_id: int,
|
121
|
-
dp_rank: int,
|
122
|
-
):
|
123
|
-
logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
|
136
|
+
return dp_port_args
|
124
137
|
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
server_args,
|
132
|
-
port_args,
|
133
|
-
base_gpu_id,
|
134
|
-
dp_rank,
|
135
|
-
)
|
138
|
+
def launch_dp_attention_schedulers(self, server_args, port_args):
|
139
|
+
self.launch_tensor_parallel_group(server_args, port_args, 0, None)
|
140
|
+
dp_port_args = []
|
141
|
+
for dp_rank in range(server_args.dp_size):
|
142
|
+
dp_port_args.append(PortArgs.init_new(server_args, dp_rank))
|
143
|
+
return dp_port_args
|
136
144
|
|
137
145
|
def launch_tensor_parallel_group(
|
138
146
|
self,
|
@@ -141,8 +149,10 @@ class DataParallelController:
|
|
141
149
|
base_gpu_id: int,
|
142
150
|
dp_rank: int,
|
143
151
|
):
|
152
|
+
if not server_args.enable_dp_attention:
|
153
|
+
logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
|
154
|
+
|
144
155
|
# Launch tensor parallel scheduler processes
|
145
|
-
scheduler_procs = []
|
146
156
|
scheduler_pipe_readers = []
|
147
157
|
tp_size_per_node = server_args.tp_size // server_args.nnodes
|
148
158
|
tp_rank_range = range(
|
@@ -150,52 +160,39 @@ class DataParallelController:
|
|
150
160
|
tp_size_per_node * (server_args.node_rank + 1),
|
151
161
|
)
|
152
162
|
for tp_rank in tp_rank_range:
|
163
|
+
rank_port_args = port_args
|
164
|
+
|
165
|
+
if server_args.enable_dp_attention:
|
166
|
+
# dp attention has different sharding logic
|
167
|
+
_, _, dp_rank = compute_dp_attention_world_info(
|
168
|
+
server_args.enable_dp_attention,
|
169
|
+
tp_rank,
|
170
|
+
server_args.tp_size,
|
171
|
+
server_args.dp_size,
|
172
|
+
)
|
173
|
+
# compute zmq ports for this dp rank
|
174
|
+
rank_port_args = PortArgs.init_new(server_args, dp_rank)
|
175
|
+
# Data parallelism resues the tensor parallelism group,
|
176
|
+
# so all dp ranks should use the same nccl port.
|
177
|
+
rank_port_args.nccl_port = port_args.nccl_port
|
178
|
+
|
153
179
|
reader, writer = mp.Pipe(duplex=False)
|
154
180
|
gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node
|
155
181
|
proc = mp.Process(
|
156
182
|
target=run_scheduler_process,
|
157
|
-
args=(server_args,
|
183
|
+
args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
|
158
184
|
)
|
159
185
|
proc.start()
|
160
|
-
scheduler_procs.append(proc)
|
186
|
+
self.scheduler_procs.append(proc)
|
161
187
|
scheduler_pipe_readers.append(reader)
|
162
188
|
|
163
|
-
|
164
|
-
self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
|
165
|
-
)
|
166
|
-
|
167
|
-
# Wait for model to finish loading and get max token nums
|
189
|
+
# Wait for model to finish loading
|
168
190
|
scheduler_info = []
|
169
191
|
for i in range(len(scheduler_pipe_readers)):
|
170
192
|
scheduler_info.append(scheduler_pipe_readers[i].recv())
|
171
193
|
|
172
194
|
self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
|
173
|
-
|
174
|
-
return send_to
|
175
|
-
|
176
|
-
def launch_tensor_parallel_process(
|
177
|
-
self,
|
178
|
-
server_args: ServerArgs,
|
179
|
-
port_args: PortArgs,
|
180
|
-
base_gpu_id: int,
|
181
|
-
dp_rank: int,
|
182
|
-
):
|
183
|
-
reader, writer = mp.Pipe(duplex=False)
|
184
|
-
gpu_id = base_gpu_id
|
185
|
-
tp_rank = dp_rank
|
186
|
-
proc = mp.Process(
|
187
|
-
target=run_scheduler_process,
|
188
|
-
args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
|
189
|
-
)
|
190
|
-
proc.start()
|
191
|
-
send_to = get_zmq_socket(
|
192
|
-
self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
|
193
|
-
)
|
194
|
-
|
195
|
-
scheduler_info = reader.recv()
|
196
|
-
self.max_total_num_tokens = scheduler_info["max_total_num_tokens"]
|
197
|
-
|
198
|
-
return send_to
|
195
|
+
self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
|
199
196
|
|
200
197
|
def round_robin_scheduler(self, req):
|
201
198
|
self.workers[self.round_robin_counter].send_pyobj(req)
|
@@ -221,8 +218,8 @@ class DataParallelController:
|
|
221
218
|
):
|
222
219
|
self.dispatching(recv_req)
|
223
220
|
else:
|
224
|
-
# Send other control messages to
|
225
|
-
for worker in self.workers:
|
221
|
+
# Send other control messages to first worker of tp group
|
222
|
+
for worker in self.workers[:: self.server_args.tp_size]:
|
226
223
|
worker.send_pyobj(recv_req)
|
227
224
|
|
228
225
|
|
@@ -238,9 +235,19 @@ def run_data_parallel_controller_process(
|
|
238
235
|
try:
|
239
236
|
controller = DataParallelController(server_args, port_args)
|
240
237
|
pipe_writer.send(
|
241
|
-
{
|
238
|
+
{
|
239
|
+
"status": "ready",
|
240
|
+
"max_total_num_tokens": controller.max_total_num_tokens,
|
241
|
+
"max_req_input_len": controller.max_req_input_len,
|
242
|
+
}
|
242
243
|
)
|
243
|
-
|
244
|
+
if server_args.node_rank == 0:
|
245
|
+
controller.event_loop()
|
246
|
+
for proc in controller.scheduler_procs:
|
247
|
+
proc.join()
|
248
|
+
logger.error(
|
249
|
+
f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
|
250
|
+
)
|
244
251
|
except Exception:
|
245
252
|
traceback = get_exception_traceback()
|
246
253
|
logger.error(f"DataParallelController hit an exception: {traceback}")
|
@@ -15,6 +15,7 @@
|
|
15
15
|
|
16
16
|
import dataclasses
|
17
17
|
import logging
|
18
|
+
import os
|
18
19
|
import signal
|
19
20
|
from collections import OrderedDict
|
20
21
|
from typing import Dict, List, Union
|
@@ -35,6 +36,12 @@ from sglang.utils import find_printable_text, get_exception_traceback
|
|
35
36
|
|
36
37
|
logger = logging.getLogger(__name__)
|
37
38
|
|
39
|
+
# Maximum number of request states that detokenizer can hold. When exceeded,
|
40
|
+
# oldest request states will be evicted. Default: 65536 (1<<16).
|
41
|
+
# For more details, see: https://github.com/sgl-project/sglang/issues/2812
|
42
|
+
# Use power of 2 values for better memory allocation.
|
43
|
+
DETOKENIZER_MAX_STATES = int(os.environ.get("SGLANG_DETOKENIZER_MAX_STATES", 1 << 16))
|
44
|
+
|
38
45
|
|
39
46
|
@dataclasses.dataclass
|
40
47
|
class DecodeStatus:
|
@@ -58,10 +65,10 @@ class DetokenizerManager:
|
|
58
65
|
# Init inter-process communication
|
59
66
|
context = zmq.Context(2)
|
60
67
|
self.recv_from_scheduler = get_zmq_socket(
|
61
|
-
context, zmq.PULL, port_args.detokenizer_ipc_name
|
68
|
+
context, zmq.PULL, port_args.detokenizer_ipc_name, True
|
62
69
|
)
|
63
70
|
self.send_to_tokenizer = get_zmq_socket(
|
64
|
-
context, zmq.PUSH, port_args.tokenizer_ipc_name
|
71
|
+
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
|
65
72
|
)
|
66
73
|
|
67
74
|
if server_args.skip_tokenizer_init:
|
@@ -71,9 +78,10 @@ class DetokenizerManager:
|
|
71
78
|
server_args.tokenizer_path,
|
72
79
|
tokenizer_mode=server_args.tokenizer_mode,
|
73
80
|
trust_remote_code=server_args.trust_remote_code,
|
81
|
+
revision=server_args.revision,
|
74
82
|
)
|
75
83
|
|
76
|
-
self.decode_status = LimitedCapacityDict()
|
84
|
+
self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES)
|
77
85
|
|
78
86
|
def trim_matched_stop(
|
79
87
|
self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
|
@@ -155,7 +163,17 @@ class DetokenizerManager:
|
|
155
163
|
# Incremental decoding
|
156
164
|
output_strs = []
|
157
165
|
for i in range(bs):
|
158
|
-
|
166
|
+
try:
|
167
|
+
s = self.decode_status[recv_obj.rids[i]]
|
168
|
+
except KeyError:
|
169
|
+
raise RuntimeError(
|
170
|
+
f"Decode status not found for request {recv_obj.rids[i]}. "
|
171
|
+
"It may be due to the request being evicted from the decode status due to memory pressure. "
|
172
|
+
"Please increase the maximum number of requests by setting "
|
173
|
+
"the SGLANG_DETOKENIZER_MAX_STATES environment variable to a bigger value than the default value. "
|
174
|
+
f"The current value is {DETOKENIZER_MAX_STATES}. "
|
175
|
+
"For more details, see: https://github.com/sgl-project/sglang/issues/2812"
|
176
|
+
)
|
159
177
|
new_text = read_texts[i][len(surr_texts[i]) :]
|
160
178
|
if recv_obj.finished_reasons[i] is None:
|
161
179
|
# Streaming chunk: update the decode status
|
@@ -183,6 +201,7 @@ class DetokenizerManager:
|
|
183
201
|
prompt_tokens=recv_obj.prompt_tokens,
|
184
202
|
completion_tokens=recv_obj.completion_tokens,
|
185
203
|
cached_tokens=recv_obj.cached_tokens,
|
204
|
+
spec_verify_ct=recv_obj.spec_verify_ct,
|
186
205
|
input_token_logprobs_val=recv_obj.input_token_logprobs_val,
|
187
206
|
input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
|
188
207
|
output_token_logprobs_val=recv_obj.output_token_logprobs_val,
|
@@ -191,13 +210,12 @@ class DetokenizerManager:
|
|
191
210
|
input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
|
192
211
|
output_top_logprobs_val=recv_obj.output_top_logprobs_val,
|
193
212
|
output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
|
194
|
-
normalized_prompt_logprob=recv_obj.normalized_prompt_logprob,
|
195
213
|
)
|
196
214
|
)
|
197
215
|
|
198
216
|
|
199
217
|
class LimitedCapacityDict(OrderedDict):
|
200
|
-
def __init__(self, capacity
|
218
|
+
def __init__(self, capacity: int, *args, **kwargs):
|
201
219
|
super().__init__(*args, **kwargs)
|
202
220
|
self.capacity = capacity
|
203
221
|
|