sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in that registry.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_one_batch.py +8 -6
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +19 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +211 -72
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +30 -29
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +144 -55
- sglang/srt/disaggregation/utils.py +155 -123
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +37 -29
- sglang/srt/entrypoints/http_server.py +153 -72
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +40 -3
- sglang/srt/layers/attention/aiter_backend.py +20 -4
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
- sglang/srt/layers/attention/flashattention_backend.py +71 -72
- sglang/srt/layers/attention/flashinfer_backend.py +10 -8
- sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -12
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +138 -130
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +28 -10
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +2 -14
- sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
- sglang/srt/layers/moe/ep_moe/layer.py +249 -33
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +107 -12
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +249 -105
- sglang/srt/lora/mem_pool.py +53 -50
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -14
- sglang/srt/managers/io_struct.py +31 -10
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +79 -37
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +220 -79
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +40 -10
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -15
- sglang/srt/mem_cache/hiradix_cache.py +38 -25
- sglang/srt/mem_cache/memory_pool.py +213 -505
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +56 -28
- sglang/srt/model_executor/cuda_graph_runner.py +198 -100
- sglang/srt/model_executor/forward_batch_info.py +32 -10
- sglang/srt/model_executor/model_runner.py +28 -12
- sglang/srt/model_loader/loader.py +16 -2
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +213 -173
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +351 -238
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
- sglang/srt/speculative/eagle_utils.py +468 -116
- sglang/srt/speculative/eagle_worker.py +258 -84
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +235 -11
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +2 -0
- sglang/test/test_utils.py +4 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -1990
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
@@ -6,11 +6,14 @@ from typing import List, Mapping, Tuple, Union
 import torch
 
 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu
 
 _is_cuda = is_cuda()
+_is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
-if not _is_cuda:
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     from vllm._custom_ops import scaled_fp8_quant
 
 
sglang/srt/layers/radix_attention.py
CHANGED
@@ -18,7 +18,6 @@ from typing import Optional
 
 from torch import nn
 
-from sglang.srt.layers.linear import UnquantizedLinearMethod
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
@@ -52,9 +51,9 @@ class RadixAttention(nn.Module):
         sliding_window_size: int = -1,
         is_cross_attention: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
-        attn_type=AttentionType.DECODER,
-        prefix: str = "",
+        attn_type: AttentionType = AttentionType.DECODER,
         use_irope: bool = False,
+        prefix: str = "",
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
sglang/srt/layers/rotary_embedding.py
CHANGED
@@ -8,10 +8,13 @@ import torch
 import torch.nn as nn
 
 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
 if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
@@ -84,7 +87,9 @@ class RotaryEmbedding(CustomOp):
         if not _is_cuda:
             cache = cache.to(dtype)
 
-        if not _is_cuda or self.head_size not in [64, 128, 256, 512]:
+        if (
+            not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512]
+        ) and not (_is_cpu and _is_cpu_amx_available):
             from vllm._custom_ops import rotary_embedding
 
             self.vllm_rotary_embedding = rotary_embedding
@@ -147,6 +152,26 @@ class RotaryEmbedding(CustomOp):
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
+    def forward_cpu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        positions = torch.add(positions, offsets) if offsets is not None else positions
+        if _is_cpu_amx_available:
+            return torch.ops.sgl_kernel.rotary_embedding_cpu(
+                positions,
+                query,
+                key,
+                self.head_size,
+                self.cos_sin_cache,
+                self.is_neox_style,
+            )
+        else:
+            return self.forward_native(positions, query, key, offsets)
+
     def forward_cuda(
         self,
         positions: torch.Tensor,
@@ -696,6 +721,21 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         key = key_rot
         return query.to(dtype), key.to(dtype)
 
+    def forward_cpu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        positions = torch.add(positions, offsets) if offsets is not None else positions
+        if _is_cpu_amx_available:
+            return torch.ops.sgl_kernel.rotary_embedding_cpu(
+                positions, query, key, self.head_size, self.cos_sin_cache, False
+            )
+        else:
+            return self.forward_native(positions, query, key, offsets)
+
 
 class Llama3RotaryEmbedding(RotaryEmbedding):
 
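The two forward_cpu methods above give the rotary embeddings (which subclass CustomOp, per the imports) a CPU path that uses the sgl-kernel AMX op when available and otherwise falls back to forward_native. The following standalone sketch illustrates that general dispatch pattern only; SimpleDispatchOp and its routing are hypothetical stand-ins, not sglang's actual CustomOp implementation:

    # Illustrative sketch, not code from this diff: a simplified stand-in for the
    # kind of platform dispatch that lets a module expose forward_cpu/forward_cuda
    # alongside a pure-PyTorch forward_native fallback.
    import torch
    import torch.nn as nn


    class SimpleDispatchOp(nn.Module):
        def forward(self, *args, **kwargs):
            # Pick the most specific implementation available on this machine.
            if torch.cuda.is_available():
                return self.forward_cuda(*args, **kwargs)
            return self.forward_cpu(*args, **kwargs)

        def forward_cpu(self, *args, **kwargs):
            # A CPU path can still defer to the reference implementation when no
            # optimized kernel (e.g., an AMX kernel) is present.
            return self.forward_native(*args, **kwargs)

        def forward_cuda(self, *args, **kwargs):
            return self.forward_native(*args, **kwargs)

        def forward_native(self, *args, **kwargs):
            raise NotImplementedError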
sglang/srt/layers/sampler.py
CHANGED
sglang/srt/lora/lora_manager.py
CHANGED
@@ -16,7 +16,7 @@
 # and "Punica: Multi-Tenant LoRA Serving"
 
 import logging
-from typing import Dict,
+from typing import Dict, Set, Tuple
 
 import torch
 
@@ -45,7 +45,6 @@ class LoRAManager:
     def __init__(
         self,
         base_model: torch.nn.Module,
-        lora_paths: Dict[str, str],
         base_hf_config: AutoConfig,
         max_loras_per_batch: int,
         load_config: LoadConfig,
@@ -55,7 +54,6 @@ class LoRAManager:
         tp_rank: int = 0,
     ):
         self.base_model: torch.nn.Module = base_model
-        self.lora_paths: Dict[str, str] = lora_paths
         self.base_hf_config: AutoConfig = base_hf_config
         self.max_loras_per_batch: int = max_loras_per_batch
         self.load_config: LoadConfig = load_config
@@ -69,8 +67,8 @@ class LoRAManager:
         backend_type = get_backend_from_name(lora_backend)
         self.lora_backend: BaseLoRABackend = backend_type(lora_backend)
 
-
-        self.
+        # Initialize mutable internal state of the LoRAManager.
+        self.init_state()
 
     def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
         self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
@@ -81,7 +79,7 @@ class LoRAManager:
             seg_indptr=torch.zeros(
                 self.max_bs_in_cuda_graph + 1, dtype=torch.int32
            ),
-            max_len=
+            max_len=1,
             weight_indices=torch.zeros(
                 self.max_bs_in_cuda_graph, dtype=torch.int32
             ),
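The CUDA-graph path here preallocates the batch-info tensors once (sized for max_bs_in_cuda_graph) and then only writes into them for each batch, so captured graphs keep pointing at the same storage. A minimal standalone sketch of that "preallocate once, update in place" pattern (plain PyTorch; MAX_BS and the helper name are made up, this is not sglang code):

    import torch

    MAX_BS = 8
    weight_indices = torch.zeros(MAX_BS, dtype=torch.int32)  # allocated once, reused


    def update_for_batch(indices):
        # Writing into the existing storage keeps data pointers stable, which is what
        # allows a captured CUDA graph (or any cached kernel arguments) to be replayed
        # with fresh per-batch values.
        n = len(indices)
        weight_indices[:n].copy_(torch.tensor(indices, dtype=torch.int32))
        return weight_indices[:n]


    print(update_for_batch([2, 0, 1]))  # tensor([2, 0, 1], dtype=torch.int32)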
@@ -89,76 +87,103 @@ class LoRAManager:
             scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
         )
 
-
-
-
+        # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
+        # across batches.
+        self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph].fill_(1)
+        torch.cumsum(
+            self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph],
+            dim=0,
+            out=self.cuda_graph_batch_info.seg_indptr[
+                1 : self.max_bs_in_cuda_graph + 1
+            ],
+        )
 
-
-
-
-
-            self.configs[name] = LoRAConfig(path)
-            self.hf_target_names.update(self.configs[name].target_modules)
+    def load_lora_adapters(self, lora_paths: Dict[str, str]):
+        """
+        Load LoRA adapters from the specified paths.
+        TODO (lifuhuang): This method should be exposed to the server/engine API to support dynamic LoRA loading.
 
-
-
-
-
-            lora_A, lora_B = get_normalized_lora_weight_names(module)
-            weights_A += lora_A
-            weights_B += lora_B
-        self.lora_weight_names: Tuple[Set[str]] = set(weights_A), set(weights_B)
+        Args:
+            lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
+                If a LoRA adapter is already loaded, it will be skipped with a warning.
+        """
 
-
-
-
-
-
-
-
-                self.load_config,
-                self.lora_backend,
-            )
-            lora_adapter.initialize_weights()
-            self.loras[name] = lora_adapter
+        for lora_name, lora_path in lora_paths.items():
+            if lora_name in self.loras:
+                logger.warning(
+                    f"LoRA adapter {lora_name} is already loaded."
+                    "If you want to reload it, please unload it first."
+                )
+                continue
 
-
-        self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
+            self.configs[lora_name] = LoRAConfig(lora_path)
 
-
-            # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
-            max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
-            scaling = list(self.loras.values())[0].scaling
-            assert all(x.hf_config["r"] == max_lora_dim for x in self.configs.values())
-            assert all(x.scaling == scaling for x in self.loras.values())
+        self.update_state_from_configs()
 
-
-
+    def unload_lora_adapters(self, lora_names: Set[str]):
+        """
+        Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
+        delete the corresponding LoRA modules.
 
-
-
-
-
-        self.
-
-
-
-            self.tp_rank,
-            self.lora_modules,
-        )
+        Args:
+            lora_names (Set[str]): A set of LoRA adapter names to unload.
+        """
+        for lora_name in lora_names:
+            if lora_name in self.loras:
+                del self.configs[lora_name]
+            else:
+                logger.warning(f"LoRA adapter {lora_name} is not loaded.")
 
-
-        self.memory_pool.init_buffers(self.lora_weight_names, self.base_model)
+        self.update_state_from_configs()
 
     def prepare_lora_batch(self, forward_batch: ForwardBatch):
         # load active loras into lora memory pool
         cur_uids = set(forward_batch.lora_paths)
         assert len(cur_uids) <= self.max_loras_per_batch
-        self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
+        self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)
 
         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
 
+        def transfer_adapter_info(
+            weight_indices_out: torch.Tensor,
+            lora_ranks_out: torch.Tensor,
+            scalings_out: torch.Tensor,
+        ):
+            """
+            Transfer adapter metadata (weight indices, LoRA rank, scalings) from host
+            to device (CUDA) asynchronously.
+            """
+            weight_indices = [0] * len(forward_batch.lora_paths)
+            lora_ranks = [0] * self.max_loras_per_batch
+            scalings = [0] * self.max_loras_per_batch
+            for i, lora_path in enumerate(forward_batch.lora_paths):
+                weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
+                if lora_path is not None:
+                    lora = self.loras[lora_path]
+                    lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
+                    scalings[weight_indices[i]] = lora.scaling
+
+            # Use pinned memory to avoid synchronizations during host-to-device transfer
+            weight_indices_tensor = torch.tensor(
+                weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
+            )
+            lora_ranks_tensor = torch.tensor(
+                lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
+            )
+            scalings_tensor = torch.tensor(
+                scalings, dtype=torch.float, pin_memory=True, device="cpu"
+            )
+
+            # Copy to device tensors asynchronously
+            weight_indices_out[:bs].copy_(weight_indices_tensor, non_blocking=True)
+            lora_ranks_out[: self.max_loras_per_batch].copy_(
+                lora_ranks_tensor, non_blocking=True
+            )
+            scalings_out[: self.max_loras_per_batch].copy_(
+                scalings_tensor, non_blocking=True
+            )
+
         if (
             hasattr(self, "max_bs_in_cuda_graph")
             and bs <= self.max_bs_in_cuda_graph
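The new transfer_adapter_info helper stages the per-batch adapter metadata in pinned (page-locked) CPU tensors and copies it to the device with non_blocking=True, so the host never has to block on the transfer. A minimal standalone illustration of that same pattern (generic PyTorch example, not sglang code):

    import torch

    if torch.cuda.is_available():
        host_values = [3, 1, 2, 0]
        # pin_memory=True gives page-locked host memory, which is what makes the
        # copy below asynchronous with respect to the CPU.
        staging = torch.tensor(host_values, dtype=torch.int32, pin_memory=True, device="cpu")
        device_buf = torch.empty(len(host_values), dtype=torch.int32, device="cuda")
        device_buf.copy_(staging, non_blocking=True)
        # The copy is ordered on the current CUDA stream; synchronize (or enqueue the
        # dependent kernels on the same stream) before reading the result.
        torch.cuda.synchronize()
        print(device_buf)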
@@ -166,51 +191,46 @@ class LoRAManager:
         ):
             # Do in-place updates when CUDA graph is enabled and the batch forward mode
             # could use CUDA graph.
-
-
-
-            self.cuda_graph_batch_info.
-
-                out=self.cuda_graph_batch_info.seg_indptr[1 : bs + 1],
+
+            transfer_adapter_info(
+                self.cuda_graph_batch_info.weight_indices,
+                self.cuda_graph_batch_info.lora_ranks,
+                self.cuda_graph_batch_info.scalings,
             )
-            self.cuda_graph_batch_info.max_len = 1
 
-
-
-                    self.memory_pool.get_buffer_id(lora_path)
-                )
-                if lora_path is not None:
-                    lora = self.loras[lora_path]
-                    self.cuda_graph_batch_info.lora_ranks[
-                        self.cuda_graph_batch_info.weight_indices[i]
-                    ] = lora.config.hf_config["r"]
-                    self.cuda_graph_batch_info.scalings[
-                        self.cuda_graph_batch_info.weight_indices[i]
-                    ] = lora.scaling
+            self.cuda_graph_batch_info.bs = bs
+            self.cuda_graph_batch_info.max_len = 1
             batch_info = self.cuda_graph_batch_info
         else:
+            weight_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
+            lora_ranks = torch.zeros(
+                (self.max_loras_per_batch,), dtype=torch.int64, device=self.device
+            )
+            scalings = torch.zeros(
+                (self.max_loras_per_batch,), dtype=torch.float, device=self.device
+            )
+            transfer_adapter_info(
+                weight_indices,
+                lora_ranks,
+                scalings,
+            )
+
             seg_lens = (
                 forward_batch.extend_seq_lens
                 if forward_batch.forward_mode.is_extend()
                 else torch.ones(bs, device=self.device)
             )
+
+            max_len = (
+                # Calculate max_len from the CPU copy to avoid D2H transfer.
+                max(forward_batch.extend_seq_lens_cpu)
+                if forward_batch.forward_mode.is_extend()
+                else 1
+            )
+
             seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
             seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
-            max_len = int(torch.max(seg_lens))
-            weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
 
-            lora_ranks = torch.zeros(
-                (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
-            )
-            scalings = torch.zeros(
-                (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
-            )
-            for i, lora_path in enumerate(forward_batch.lora_paths):
-                weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
-                if lora_path is not None:
-                    lora = self.loras[lora_path]
-                    lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
-                    scalings[weight_indices[i]] = lora.scaling
             batch_info = LoRABatchInfo(
                 bs=bs,
                 seg_lens=seg_lens,
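Both branches above describe the batch with per-request seg_lens and a CSR-style seg_indptr built by a prefix sum, while max_len is now taken from the CPU copy of the lengths to avoid a device-to-host transfer. A small standalone example of that indptr bookkeeping (plain PyTorch with made-up values, not sglang code):

    import torch

    seg_lens = torch.tensor([3, 1, 4], dtype=torch.int32)   # tokens per request
    seg_indptr = torch.zeros(seg_lens.numel() + 1, dtype=torch.int32)
    seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)          # offsets [0, 3, 4, 8]
    max_len = int(seg_lens.max())                           # 4

    print(seg_indptr, max_len)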
@@ -222,9 +242,16 @@ class LoRAManager:
         )
         self.lora_backend.set_batch_info(batch_info)
 
-        #
-
-
+        # TODO (lifuhuang): one potential perf optimization that is worth considering is to see if we can call
+        # this method only when loading/unloading LoRA adapters, instead of calling it for every micro-batch.
+        self.update_lora_info()
+
+    def update_lora_info(self):
+        """
+        Update all LoRA modules to associate them with the latest memory buffer.
+        """
+        for layer_id, layer_modules in self.lora_modules.items():
+            for module_name, module in layer_modules.items():
                 if "qkv_proj" in module_name:
                     module.set_lora_info(
                         self.memory_pool.get_tensor(
@@ -250,23 +277,139 @@ class LoRAManager:
                        ),
                    )
 
+    def init_state(self):
+        """
+        Initialize the internal (mutable) state of the LoRAManager.
+
+        These states are mutable via the `update_state_from_configs` as LoRA adapters are loaded and unloaded dynamically.
+        """
+
+        # Configs of all active LoRA adapters.
+        self.configs: Dict[str, LoRAConfig] = {}
+
+        # LoRA adapter weights cached in CPU memory.
+        self.loras: Dict[str, LoRAAdapter] = {}
+
+        # Supported weight names (e.g., qkv_proj) for LoRA A and B respectively.
+        self.lora_weight_names: Tuple[Set[str]] = (set(), set())
+
+        # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
+        self.lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]] = {
+            i: {} for i in range(self.base_hf_config.num_hidden_layers)
+        }
+
+        # Initialize memory pool
+        self.memory_pool = LoRAMemoryPool(
+            self.base_hf_config,
+            self.max_loras_per_batch,
+            self.dtype,
+            self.tp_size,
+            self.tp_rank,
+        )
+
+    def update_state_from_configs(self):
+        """
+        Update the internal state of the LoRAManager based on the current `self.configs`. This method
+        should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).
+
+        This includes:
+        - Initializing LoRA adapters if they are not already loaded.
+        - Collect all LoRA weight names based on the current loaded adapters.
+        - Lazily monkey-patching the base model to use LoRA layers where applicable.
+        - Preparing the GPU buffer pool for active LoRA weights.
+        """
+
+        # Target module names in huggingface lora configs.
+        # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
+        hf_target_module_names: Set[str] = set()
+        for config in self.configs.values():
+            hf_target_module_names.update(config.target_modules)
+        max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
+
+        # Loads / unloads LoRA adapters based on the latest configs.
+        self.update_lora_adapters()
+
+        # Lazily update states for new LoRA weight name (e.g., qkv_proj) as needed.
+        #
+        # Please note that the following update operations are "monotonic" by design, meaning that we update
+        # multiple places to support the new weight names when the first adapter targeting such weight names
+        # is loaded. However, we never "rollback" the support (e.g., convert LoRA layer back to base layer)
+        # even if the associated adapters are unloaded later for both simplicity and practicality reasons: the
+        # list of LoRA weight names is expected to be extremely finite and stable.
+        self.update_lora_weight_names(hf_target_module_names)
+        self.update_lora_modules(hf_target_module_names)
+        self.update_memory_buffers(max_lora_dim)
+
+    def update_lora_weight_names(self, hf_target_names: Set[str]):
+        """
+        Add new LoRA weight names if needed based on the current `self.configs`.
+        """
+
+        # Target lora weight names for lora_a and lora_b modules respectively.
+        for module in hf_target_names:
+            lora_A, lora_B = get_normalized_lora_weight_names(module)
+            self.lora_weight_names[0].update(lora_A)
+            self.lora_weight_names[1].update(lora_B)
+
+    def update_lora_adapters(self):
+        """
+        Update the LoRA adapters in CPU memory based on the current `self.configs`.
+        It loads any new adapters that are not already loaded, and unloads any adapters
+        that are no longer in `self.configs` (e.g., unloaded).
+        """
+
+        # Load new adapter weights to cpu
+        for name, config in self.configs.items():
+            if name not in self.loras:
+                logger.info(f"Loading weight of LoRA adapter {name} from {config.path}")
+                lora_adapter = LoRAAdapter(
+                    name,
+                    config,
+                    self.base_hf_config,
+                    self.load_config,
+                    self.lora_backend,
+                )
+                lora_adapter.initialize_weights()
+                self.loras[name] = lora_adapter
+
+        # Clean up unused LoRA adapters
+        for name in self.loras:
+            if name not in self.configs:
+                logger.info(f"Unloading LoRA adapter {name}")
+                del self.loras[name]
+
+        # Additional checks for flashinfer backend
+        # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
+        if self.lora_backend == "flashinfer":
+            lora_dims = set(x.hf_config["r"] for x in self.configs.values())
+            scalings = set(x.scaling for x in self.loras.values())
+            assert (
+                len(lora_dims) == 1 and len(scalings) == 1
+            ), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "
+
+    def update_memory_buffers(self, max_lora_dim: int):
+        """
+        Update the LoRA memory pool buffers based on the current LoRA configurations and update
+        LoRA modules to use the new buffers. This method should be called after the LoRA configurations
+        are set or updated.
+        """
+
+        self.memory_pool.init_buffers(
+            self.lora_weight_names, self.base_model, max_lora_dim
+        )
+
     def set_lora_module(self, module_name, module):
         lora_module = get_lora_layer(module, self.lora_backend)
         replace_submodule(self.base_model, module_name, lora_module)
         return lora_module
 
-    def
+    def update_lora_modules(self, hf_target_names: Set[str]):
         # Target module names of customized layers defined in python/sglang/srt/layers
         # e.g., {"qkv_proj", "o_proj"}
         customized_target_names = get_customized_names_from_hf_names(
-
+            hf_target_names, self.base_model
         )
 
-        # Monkey patch to use the LoRA version layers
-        self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = {
-            i: [] for i in range(self.base_hf_config.num_hidden_layers)
-        }
-
         for module_name, module in self.base_model.named_modules():
             # TODO (lifuhuang): in the future, we should consider generalizing the
             # should_apply_lora function to support mapping by full module name instead
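update_lora_adapters above reconciles the set of adapters held in CPU memory with self.configs: new entries are loaded and initialized, stale ones are dropped. A generic standalone sketch of that reconciliation pattern follows (illustrative only, not sglang code; the names and the fake "loading" are made up). Iterating over a snapshot with list(loaded) is one simple way to avoid mutating a dict while iterating over it:

    from typing import Dict


    def sync_adapters(configs: Dict[str, str], loaded: Dict[str, str]) -> None:
        # Load anything configured but not yet loaded.
        for name, path in configs.items():
            if name not in loaded:
                loaded[name] = f"weights from {path}"  # placeholder for real weight loading
        # Unload anything loaded that is no longer configured; iterate over a snapshot
        # so the dict is not mutated while being iterated.
        for name in list(loaded):
            if name not in configs:
                del loaded[name]


    loaded: Dict[str, str] = {}
    sync_adapters({"a": "/tmp/a", "b": "/tmp/b"}, loaded)
    sync_adapters({"b": "/tmp/b"}, loaded)  # "a" is dropped here
    print(sorted(loaded))  # ['b']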
@@ -281,6 +424,7 @@ class LoRAManager:
             # The module should be converted if it is included in target_names
             if module_name.split(".")[-1] in customized_target_names:
                 layer_id = get_layer_id(module_name)
-                self.lora_modules[layer_id]
-
-
+                if module_name not in self.lora_modules[layer_id]:
+                    self.lora_modules[layer_id][module_name] = self.set_lora_module(
+                        module_name, module
+                    )