sglang 0.3.5.post2__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/api.py +2 -2
- sglang/bench_latency.py +1 -553
- sglang/bench_offline_throughput.py +48 -20
- sglang/bench_one_batch.py +472 -0
- sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
- sglang/bench_serving.py +125 -6
- sglang/check_env.py +3 -6
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +2 -2
- sglang/srt/configs/model_config.py +13 -14
- sglang/srt/constrained/__init__.py +13 -14
- sglang/srt/constrained/base_grammar_backend.py +13 -15
- sglang/srt/constrained/outlines_backend.py +28 -17
- sglang/srt/constrained/outlines_jump_forward.py +13 -15
- sglang/srt/constrained/xgrammar_backend.py +47 -58
- sglang/srt/conversation.py +13 -15
- sglang/srt/hf_transformers_utils.py +13 -15
- sglang/srt/layers/activation.py +16 -13
- sglang/srt/layers/attention/flashinfer_backend.py +106 -54
- sglang/srt/layers/attention/triton_backend.py +9 -7
- sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
- sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
- sglang/srt/layers/custom_op_util.py +25 -0
- sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
- sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +11 -4
- sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
- sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
- sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
- sglang/srt/layers/fused_moe_triton/layer.py +633 -0
- sglang/srt/layers/layernorm.py +17 -15
- sglang/srt/layers/logits_processor.py +23 -25
- sglang/srt/layers/quantization/__init__.py +77 -17
- sglang/srt/layers/radix_attention.py +13 -15
- sglang/srt/layers/rotary_embedding.py +13 -13
- sglang/srt/layers/sampler.py +4 -8
- sglang/srt/layers/torchao_utils.py +2 -0
- sglang/srt/lora/lora.py +13 -14
- sglang/srt/lora/lora_config.py +13 -14
- sglang/srt/lora/lora_manager.py +22 -24
- sglang/srt/managers/data_parallel_controller.py +98 -27
- sglang/srt/managers/detokenizer_manager.py +13 -15
- sglang/srt/managers/io_struct.py +63 -21
- sglang/srt/managers/schedule_batch.py +154 -59
- sglang/srt/managers/schedule_policy.py +18 -16
- sglang/srt/managers/scheduler.py +278 -109
- sglang/srt/managers/session_controller.py +61 -0
- sglang/srt/managers/tokenizer_manager.py +63 -18
- sglang/srt/managers/tp_worker.py +25 -16
- sglang/srt/managers/tp_worker_overlap_thread.py +62 -67
- sglang/srt/metrics/collector.py +13 -15
- sglang/srt/metrics/func_timer.py +13 -15
- sglang/srt/mm_utils.py +13 -14
- sglang/srt/model_executor/cuda_graph_runner.py +63 -25
- sglang/srt/model_executor/forward_batch_info.py +128 -32
- sglang/srt/model_executor/model_runner.py +132 -64
- sglang/srt/model_parallel.py +98 -0
- sglang/srt/models/chatglm.py +15 -16
- sglang/srt/models/commandr.py +15 -16
- sglang/srt/models/dbrx.py +15 -16
- sglang/srt/models/deepseek.py +15 -15
- sglang/srt/models/deepseek_v2.py +162 -59
- sglang/srt/models/exaone.py +14 -15
- sglang/srt/models/gemma.py +14 -14
- sglang/srt/models/gemma2.py +31 -25
- sglang/srt/models/gemma2_reward.py +13 -14
- sglang/srt/models/gpt_bigcode.py +14 -14
- sglang/srt/models/grok.py +15 -15
- sglang/srt/models/internlm2.py +13 -15
- sglang/srt/models/internlm2_reward.py +13 -14
- sglang/srt/models/llama.py +21 -21
- sglang/srt/models/llama_classification.py +13 -14
- sglang/srt/models/llama_reward.py +13 -14
- sglang/srt/models/llava.py +14 -16
- sglang/srt/models/llavavid.py +14 -16
- sglang/srt/models/minicpm.py +13 -15
- sglang/srt/models/minicpm3.py +13 -15
- sglang/srt/models/mistral.py +13 -15
- sglang/srt/models/mixtral.py +15 -15
- sglang/srt/models/mixtral_quant.py +14 -14
- sglang/srt/models/olmo.py +22 -20
- sglang/srt/models/olmoe.py +23 -20
- sglang/srt/models/phi3_small.py +447 -0
- sglang/srt/models/qwen.py +14 -14
- sglang/srt/models/qwen2.py +22 -19
- sglang/srt/models/qwen2_moe.py +17 -18
- sglang/srt/models/qwen2_vl.py +13 -6
- sglang/srt/models/stablelm.py +18 -16
- sglang/srt/models/torch_native_llama.py +107 -93
- sglang/srt/models/xverse.py +13 -14
- sglang/srt/models/xverse_moe.py +15 -16
- sglang/srt/models/yivl.py +13 -15
- sglang/srt/openai_api/adapter.py +19 -17
- sglang/srt/openai_api/protocol.py +14 -16
- sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
- sglang/srt/sampling/sampling_batch_info.py +61 -57
- sglang/srt/sampling/sampling_params.py +14 -16
- sglang/srt/server.py +86 -35
- sglang/srt/server_args.py +96 -80
- sglang/srt/utils.py +266 -68
- sglang/test/few_shot_gsm8k.py +8 -4
- sglang/test/runners.py +38 -20
- sglang/test/srt/sampling/penaltylib/utils.py +23 -21
- sglang/test/test_utils.py +31 -20
- sglang/version.py +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +66 -57
- sglang-0.3.6.post1.dist-info/RECORD +164 -0
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +1 -1
- sglang/srt/layers/fused_moe/__init__.py +0 -1
- sglang-0.3.5.post2.dist-info/RECORD +0 -156
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py
CHANGED
@@ -1,18 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Logits processing."""
 
 import dataclasses
@@ -62,21 +60,21 @@ class LogitsMetadata:
 
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
+        extend_logprob_pruned_lens_cpu = None
+
         if forward_batch.return_logprob:
             return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
+            if forward_batch.forward_mode.is_extend():
+                extend_logprob_pruned_lens_cpu = [
+                    extend_len - start_len
+                    for extend_len, start_len in zip(
+                        forward_batch.extend_seq_lens_cpu,
+                        forward_batch.extend_logprob_start_lens_cpu,
+                    )
+                ]
         else:
             return_top_logprob = False
 
-        if forward_batch.forward_mode.is_extend():
-            extend_logprob_pruned_lens_cpu = [
-                extend_len - start_len
-                for extend_len, start_len in zip(
-                    forward_batch.extend_seq_lens,
-                    forward_batch.extend_logprob_start_lens_cpu,
-                )
-            ]
-        else:
-            extend_logprob_pruned_lens_cpu = None
         return cls(
             forward_mode=forward_batch.forward_mode,
             top_logprobs_nums=forward_batch.top_logprobs_nums,
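Note on the logits_processor.py hunks above: the pruned-length computation now runs only when logprobs are requested, and it iterates over extend_seq_lens_cpu (a host-side copy) instead of the device tensor extend_seq_lens, so the list comprehension no longer reads GPU data element by element. A minimal sketch of the same computation, with made-up inputs standing in for the ForwardBatch fields:

# Hypothetical stand-ins for forward_batch.extend_seq_lens_cpu and
# forward_batch.extend_logprob_start_lens_cpu.
extend_seq_lens_cpu = [12, 7, 30]           # tokens extended per request
extend_logprob_start_lens_cpu = [4, 0, 10]  # prefix tokens skipped for logprobs

extend_logprob_pruned_lens_cpu = [
    extend_len - start_len
    for extend_len, start_len in zip(
        extend_seq_lens_cpu, extend_logprob_start_lens_cpu
    )
]
assert extend_logprob_pruned_lens_cpu == [8, 7, 20]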
sglang/srt/layers/quantization/__init__.py
CHANGED
@@ -1,18 +1,19 @@
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
 
-from typing import Dict, Type
+from typing import Callable, Dict, Optional, Type
 
+import torch
 from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
 from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
 from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
-from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
     CompressedTensorsConfig,
 )
 from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
 from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
 from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
-from vllm.model_executor.layers.quantization.fp8 import Fp8Config
+from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from vllm.model_executor.layers.quantization.gguf import GGUFConfig
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
@@ -30,8 +31,6 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "tpu_int8": Int8TpuConfig,
     "fp8": Fp8Config,
     "fbgemm_fp8": FBGEMMFp8Config,
-    # The order of gptq methods is important for config.py iteration over
-    # override_quantization_method(..)
     "marlin": MarlinConfig,
     "gguf": GGUFConfig,
     "gptq_marlin_24": GPTQMarlin24Config,
@@ -47,20 +46,68 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
 
 def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     if quantization not in QUANTIZATION_METHODS:
-        raise ValueError(
+        raise ValueError(
+            f"Invalid quantization method: {quantization}. "
+            f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
+        )
     return QUANTIZATION_METHODS[quantization]
 
 
-
-
-
-
-
+def fp8_moe_apply(
+    self,
+    layer: torch.nn.Module,
+    x: torch.Tensor,
+    router_logits: torch.Tensor,
+    top_k: int,
+    renormalize: bool,
+    use_grouped_topk: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
+) -> torch.Tensor:
+    """Enhanced apply method for FP8 MoE."""
+    from sglang.srt.layers.fused_moe_triton import FusedMoE
+    from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
+
+    # Expert selection
+    topk_weights, topk_ids = FusedMoE.select_experts(
+        hidden_states=x,
+        router_logits=router_logits,
+        use_grouped_topk=use_grouped_topk,
+        top_k=top_k,
+        renormalize=renormalize,
+        topk_group=topk_group,
+        num_expert_group=num_expert_group,
+        custom_routing_function=custom_routing_function,
+    )
+
+    # Expert fusion with FP8 quantization
+    return fused_experts(
+        x,
+        layer.w13_weight,
+        layer.w2_weight,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        inplace=True,
+        use_fp8_w8a8=True,
+        w1_scale=layer.w13_weight_scale,
+        w2_scale=layer.w2_weight_scale,
+        a1_scale=layer.w13_input_scale,
+        a2_scale=layer.w2_input_scale,
+    )
+
+
+def fp8_get_quant_method(self, layer, prefix):
+    """Enhanced get_quant_method for FP8 config."""
+    from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
+    from vllm.model_executor.layers.quantization.utils.quant_utils import (
+        is_layer_skipped,
+    )
+
+    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+    from sglang.srt.layers.linear import UnquantizedLinearMethod
 
-"""
-def fp8_get_quant_method(
-    self, layer: torch.nn.Module, prefix: str
-) -> Optional["QuantizeMethodBase"]:
     if isinstance(layer, LinearBase):
         if is_layer_skipped(prefix, self.ignored_layers):
             return UnquantizedLinearMethod()
@@ -70,5 +117,18 @@ def fp8_get_quant_method(
     return None
 
 
-
-"""
+def apply_monkey_patches():
+    """Apply all monkey patches in one place."""
+    setattr(Fp8MoEMethod, "apply", fp8_moe_apply)
+    setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
+
+
+# Apply patches when module is imported
+apply_monkey_patches()
+
+
+__all__ = [
+    "QuantizationConfig",
+    "get_quantization_config",
+    "QUANTIZATION_METHODS",
+]
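Note on the quantization/__init__.py change above: a previously commented-out override is now live code. At import time the module rebinds vLLM's Fp8MoEMethod.apply and Fp8Config.get_quant_method to sglang's Triton-based implementations via setattr. A generic sketch of that monkey-patching pattern (the class and function names here are illustrative, not from either package):

class Upstream:
    def apply(self) -> str:
        return "original"

def patched_apply(self) -> str:
    # A plain function assigned onto the class becomes a bound method,
    # so existing call sites pick up the new behavior transparently.
    return "patched"

def apply_monkey_patches() -> None:
    setattr(Upstream, "apply", patched_apply)

apply_monkey_patches()  # run once at import time, as the module above does
assert Upstream().apply() == "patched"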
sglang/srt/layers/radix_attention.py
CHANGED
@@ -1,18 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Radix attention."""
 
 from torch import nn
sglang/srt/layers/rotary_embedding.py
CHANGED
@@ -1,16 +1,16 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """MRotaryEmbedding"""
 from typing import Any, Dict, List, Optional, Tuple, Union
 
sglang/srt/layers/sampler.py
CHANGED
@@ -1,5 +1,4 @@
 import logging
-import os
 from typing import Union
 
 import torch
@@ -8,7 +7,7 @@ from torch import nn
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
 
 if is_flashinfer_available():
     from flashinfer.sampling import (
@@ -19,17 +18,13 @@ if is_flashinfer_available():
     )
 
 
-# Crash on warning if we are running CI tests
-crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
-
-
 logger = logging.getLogger(__name__)
 
 
 class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
-        self.use_nan_detectioin =
+        self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"]
 
     def forward(
         self,
@@ -46,7 +41,8 @@ class Sampler(nn.Module):
             logits = torch.where(
                 torch.isnan(logits), torch.full_like(logits, -1e5), logits
             )
-
+            if crash_on_warnings():
+                raise ValueError("Detected errors during sampling! NaN in the logits.")
 
         if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
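Note on the sampler.py change above: the module-level SGLANG_IS_IN_CI check is replaced by the shared crash_on_warnings() helper, the NaN-detection switch now comes from global_server_args_dict, and a NaN in the logits raises instead of passing silently when warnings are fatal. A standalone sketch of the masking step (the logits tensor is made up):

import torch

logits = torch.tensor([[0.5, float("nan"), 1.2]])
if torch.any(torch.isnan(logits)):
    # Replace NaNs with a large negative value so they lose in argmax/softmax.
    logits = torch.where(torch.isnan(logits), torch.full_like(logits, -1e5), logits)
assert not torch.isnan(logits).any()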
sglang/srt/lora/lora.py
CHANGED
@@ -1,17 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 # Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
 # and "Punica: Multi-Tenant LoRA Serving"
sglang/srt/lora/lora_config.py
CHANGED
@@ -1,17 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 import json
 import os
sglang/srt/lora/lora_manager.py
CHANGED
@@ -1,22 +1,20 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 # Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
 # and "Punica: Multi-Tenant LoRA Serving"
 
-
 import logging
 import re
 
@@ -146,9 +144,9 @@ class LoRAManager:
             }
         else:
             logger.warning(
-
-
-
+                "WARNING: get_module_name() is not defined, "
+                "which is used to map config module name to model implementation module name."
+                "Use the default one, but please check if it is correct for your model."
             )
         self.target_modules = {
             get_module_name(module) for module in self.origin_target_modules
@@ -194,9 +192,9 @@ class LoRAManager:
             hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A)
         else:
             logger.warning(
-
-
-
+                "WARNING: get_hidden_dim() is not defined, "
+                "which is used to get the hidden dim for different lora modules"
+                "Use the default one, but please check if it is correct for your model."
             )
             hidden_dim_A, _ = get_hidden_dim(module_A, self.base_hf_config)
             c = self.loras[-1].get_stacked_multiply(module_A)
@@ -218,9 +216,9 @@ class LoRAManager:
             _, hidden_dim_B = self.base_model.get_hidden_dim(module_B)
         else:
             logger.warning(
-
-
-
+                "WARNING: get_hidden_dim() is not defined, "
+                "which is used to get the hidden dim for different lora modules"
+                "Use the default one, but please check if it is correct for your model."
            )
             _, hidden_dim_B = get_hidden_dim(module_B, self.base_hf_config)
             c = self.loras[-1].get_stacked_multiply(module_B)
sglang/srt/managers/data_parallel_controller.py
CHANGED
@@ -1,22 +1,21 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """A controller that dispatches requests to multiple data parallel workers."""
 
 import logging
 import multiprocessing as mp
+import threading
 from enum import Enum, auto
 
 import zmq
@@ -28,6 +27,7 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
+    bind_port,
     configure_logger,
     get_zmq_socket,
     kill_parent_process,
@@ -80,20 +80,62 @@ class DataParallelController:
 
         # Start data parallel workers
         base_gpu_id = 0
-        self.workers = []
+        self.workers = [None] * server_args.dp_size
+
+        threads = []
+        sockets = []
         for dp_rank in range(server_args.dp_size):
             tmp_port_args = PortArgs.init_new(server_args)
+            tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
             tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
 
-
-
-
-
-
+            if server_args.enable_dp_attention:
+                # Data parallelism resues the tensor parallelism group,
+                # so all dp ranks should use the same nccl port.
+                tmp_port_args.nccl_port = port_args.nccl_port
+            else:
+                # This port is checked free in PortArgs.init_new.
+                # We hold it first so that the next dp worker gets a different port
+                sockets.append(bind_port(tmp_port_args.nccl_port))
+
+            # Create a thread for each worker
+            thread = threading.Thread(
+                target=self.launch_worker_func,
+                args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
             )
+            threads.append(thread)
+            base_gpu_id += 1 if server_args.enable_dp_attention else server_args.tp_size
+
+        # Free all sockets before starting the threads to launch TP workers
+        for sock in sockets:
+            sock.close()
+
+        # Start all threads
+        for thread in threads:
+            thread.start()
+        for thread in threads:
+            thread.join()
+
+    def launch_worker_func(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        base_gpu_id: int,
+        dp_rank: int,
+    ):
+        logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
 
-
-
+        launch_func_ = (
+            self.launch_tensor_parallel_process
+            if server_args.enable_dp_attention
+            else self.launch_tensor_parallel_group
+        )
+        self.workers[dp_rank] = launch_func_(
+            server_args,
+            port_args,
+            base_gpu_id,
+            dp_rank,
+        )
 
     def launch_tensor_parallel_group(
         self,
@@ -112,7 +154,7 @@ class DataParallelController:
         )
         for tp_rank in tp_rank_range:
             reader, writer = mp.Pipe(duplex=False)
-            gpu_id = base_gpu_id + tp_rank % tp_size_per_node
+            gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node
             proc = mp.Process(
                 target=run_scheduler_process,
                 args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
@@ -125,9 +167,36 @@ class DataParallelController:
             self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
         )
 
-        # Wait for model to finish loading
+        # Wait for model to finish loading and get max token nums
+        scheduler_info = []
         for i in range(len(scheduler_pipe_readers)):
-            scheduler_pipe_readers[i].recv()
+            scheduler_info.append(scheduler_pipe_readers[i].recv())
+
+        self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
+
+        return send_to
+
+    def launch_tensor_parallel_process(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        base_gpu_id: int,
+        dp_rank: int,
+    ):
+        reader, writer = mp.Pipe(duplex=False)
+        gpu_id = base_gpu_id
+        tp_rank = dp_rank
+        proc = mp.Process(
+            target=run_scheduler_process,
+            args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
+        )
+        proc.start()
+        send_to = get_zmq_socket(
+            self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
+        )
+
+        scheduler_info = reader.recv()
+        self.max_total_num_tokens = scheduler_info["max_total_num_tokens"]
 
         return send_to
 
@@ -170,7 +239,9 @@ def run_data_parallel_controller_process(
 
     try:
         controller = DataParallelController(server_args, port_args)
-        pipe_writer.send(
+        pipe_writer.send(
+            {"status": "ready", "max_total_num_tokens": controller.max_total_num_tokens}
+        )
        controller.event_loop()
     except Exception:
         msg = get_exception_traceback()