sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +164 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +62 -23
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +26 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +4 -5
- sglang/srt/managers/data_parallel_controller.py +31 -9
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -31
- sglang/srt/managers/scheduler.py +325 -38
- sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +27 -8
- sglang/srt/mem_cache/memory_pool.py +258 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +85 -28
- sglang/srt/model_executor/forward_batch_info.py +81 -15
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +326 -192
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +145 -47
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +104 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/lora/lora_manager.py
CHANGED
@@ -23,7 +23,7 @@ import torch
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.hf_transformers_utils import AutoConfig
 from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name
-from sglang.srt.lora.layers import get_lora_layer
+from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
 from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.lora.mem_pool import LoRAMemoryPool
@@ -51,6 +51,8 @@ class LoRAManager:
         load_config: LoadConfig,
         dtype: torch.dtype,
         lora_backend: str = "triton",
+        tp_size: int = 1,
+        tp_rank: int = 0,
     ):
         self.base_model: torch.nn.Module = base_model
         self.lora_paths: Dict[str, str] = lora_paths
@@ -58,6 +60,9 @@ class LoRAManager:
         self.max_loras_per_batch: int = max_loras_per_batch
         self.load_config: LoadConfig = load_config
         self.dtype: torch.dtype = dtype
+        self.device: torch.device = next(self.base_model.parameters()).device
+        self.tp_size: int = tp_size
+        self.tp_rank: int = tp_rank

         # LoRA backend for running sgemm kernels
         logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
@@ -110,7 +115,13 @@ class LoRAManager:
     def init_lora_memory_pool(self):
         # Initialize memory pool
         self.memory_pool = LoRAMemoryPool(
-            self.base_hf_config,
+            self.base_hf_config,
+            self.max_loras_per_batch,
+            self.max_lora_dim,
+            self.dtype,
+            self.tp_size,
+            self.tp_rank,
+            self.lora_modules,
         )

         # Initialize target lora modules in memory pool
@@ -131,12 +142,12 @@ class LoRAManager:
         seg_lens = (
             forward_batch.extend_seq_lens
             if forward_batch.forward_mode.is_extend()
-            else torch.ones(bs, device=
+            else torch.ones(bs, device=self.device)
         )
-        seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=
+        seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
         seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
         max_len = int(torch.max(seg_lens))
-        weight_indices = torch.empty((bs,), dtype=torch.int64, device=
+        weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
         for i, lora_path in enumerate(forward_batch.lora_paths):
             weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)

@@ -150,22 +161,32 @@ class LoRAManager:
         self.lora_backend.set_batch_info(batch_info)

         # call set_lora_info for each lora modules
-        for
+        for layer_id, modules in self.lora_modules.items():
+            for module_name, module in modules:
+                if "qkv_proj" in module_name:
+                    module.set_lora_info(
+                        self.memory_pool.get_tensor(
+                            "qkv_proj", layer_id, LoRAType.LORA_A
+                        ),
+                        self.memory_pool.get_tensor(
+                            "q_proj", layer_id, LoRAType.LORA_B
+                        ),
+                        self.memory_pool.get_tensor(
+                            "kv_proj", layer_id, LoRAType.LORA_B
+                        ),
+                    )
+                else:
+                    weight_name = get_weight_name(
+                        module_name, self.lora_weight_names, LoRAType.LORA_A
+                    )
+                    module.set_lora_info(
+                        self.memory_pool.get_tensor(
+                            weight_name, layer_id, LoRAType.LORA_A
+                        ),
+                        self.memory_pool.get_tensor(
+                            weight_name, layer_id, LoRAType.LORA_B
+                        ),
+                    )

     def set_lora_module(self, module_name, module):
         lora_module = get_lora_layer(
@@ -182,10 +203,13 @@ class LoRAManager:
         )

         # Monkey patch to use the LoRA version layers
-        self.lora_modules: List[Tuple[str,
+        self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = {
+            i: [] for i in range(self.base_hf_config.num_hidden_layers)
+        }
         for module_name, module in self.base_model.named_modules():
             # The module should be converted if it is included in target_names
             if module_name.split(".")[-1] in customized_target_names:
+                layer_id = get_layer_id(module_name)
+                self.lora_modules[layer_id].append(
                     (module_name, self.set_lora_module(module_name, module))
                 )
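
The new per-layer `lora_modules` dict is keyed by the index returned by `get_layer_id`, a helper whose definition is not part of this diff (it presumably lives next to `get_weight_name` in `sglang.srt.lora.utils`). A minimal, purely illustrative sketch of what such a helper could do, assuming Hugging Face style module paths:

```python
import re


def get_layer_id_sketch(module_name: str) -> int:
    """Illustrative stand-in for get_layer_id: pull the decoder layer index
    out of a module path such as "model.layers.11.self_attn.qkv_proj"."""
    match = re.search(r"layers\.(\d+)\.", module_name)
    if match is None:
        raise ValueError(f"cannot infer layer id from {module_name!r}")
    return int(match.group(1))


assert get_layer_id_sketch("model.layers.11.self_attn.qkv_proj") == 11
```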
sglang/srt/lora/mem_pool.py
CHANGED
@@ -2,9 +2,12 @@ from typing import Dict, List, Optional, Set, Tuple

 import torch

+from sglang.srt.distributed import divide
 from sglang.srt.hf_transformers_utils import AutoConfig
+from sglang.srt.lora.layers import BaseLayerWithLoRA
 from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.utils import (
+    ROW_PARALLELISM_LINEAR_LORA_NAMES,
     LoRAType,
     get_hidden_dim,
     get_stacked_multiply,
@@ -21,6 +24,9 @@ class LoRAMemoryPool:
         max_loras_per_batch: int,
         max_lora_dim: int,
         dtype: torch.dtype,
+        tp_size: int,
+        tp_rank: int,
+        lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]],
     ):

         self.base_hf_config: AutoConfig = base_hf_config
@@ -28,6 +34,9 @@ class LoRAMemoryPool:
         self.max_loras_per_batch: int = max_loras_per_batch
         self.max_lora_dim: int = max_lora_dim
         self.dtype: torch.dtype = dtype
+        self.tp_size: int = tp_size
+        self.tp_rank: int = tp_rank
+        self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = lora_modules

         # Both A_buffer and B_buffer maps lora weight names to its buffer space.
         # A_buffer contains num_layer number of row-major tensors with shape
@@ -45,6 +54,41 @@ class LoRAMemoryPool:
         # Here we don't initalize to None since None is a valid uid
         self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch

+    def get_lora_A_shape(
+        self, module_name: str, base_model: torch.nn.Module
+    ) -> Tuple[int]:
+        """
+        Given a module_name (might be a stacked name), return the hidden dims of modules's input and output.
+        """
+        input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model)
+        c = get_stacked_multiply(module_name)
+        if self.tp_size > 1:
+            if module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
+                input_dim = divide(input_dim, self.tp_size)
+        return (
+            self.max_loras_per_batch,
+            self.max_lora_dim * c,
+            input_dim,
+        )
+
+    def get_lora_B_shape(
+        self, module_name: str, base_model: torch.nn.Module
+    ) -> Tuple[int]:
+        """
+        Given a module_name (might be a stacked name), return the hidden dims of modules's input and output.
+        """
+        _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
+        c = get_stacked_multiply(module_name)
+        if self.tp_size > 1:
+            if module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
+                output_dim = divide(output_dim, self.tp_size)
+        return (
+            c,
+            self.max_loras_per_batch,
+            output_dim,
+            self.max_lora_dim,
+        )
+
     def init_buffers(
         self,
         lora_weight_names: Set[Tuple[str]],
@@ -54,42 +98,31 @@ class LoRAMemoryPool:
         # lora_weight_names is a set of name pairs indicating each pair of lora modules to load
         # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
         self.lora_weight_names: Set[Tuple[str]] = lora_weight_names
-        for
-            self.
-                (
-                    c, # stacked lora_b modules might need separation
-                    self.max_loras_per_batch,
-                    output_dim,
-                    self.max_lora_dim,
-                ),
-                dtype=self.dtype,
-                device="cuda",
-            )
-            for i in range(self.num_layer)
-        ]
+        device = next(base_model.parameters()).device
+        lora_module_A_names = set([name[0] for name in lora_weight_names])
+        lora_module_B_names = set([name[1] for name in lora_weight_names])
+        # Init A tensor, column_major=False
+        for module_A in lora_module_A_names:
+            lora_A_shape = self.get_lora_A_shape(module_A, base_model)
+            self.A_buffer[module_A] = [
+                torch.empty(
+                    lora_A_shape,
+                    dtype=self.dtype,
+                    device=device,
+                )
+                for i in range(self.num_layer)
+            ]
+        # Init B tensor, column_major=True
+        for module_B in lora_module_B_names:
+            lora_B_shape = self.get_lora_B_shape(module_B, base_model)
+            self.B_buffer[module_B] = [
+                torch.empty(
+                    lora_B_shape,
+                    dtype=self.dtype,
+                    device=device,
+                )
+                for _ in range(self.num_layer)
+            ]

     def prepare_lora_batch(
         self,
@@ -136,30 +169,56 @@ class LoRAMemoryPool:
         assert lora_adapter is not None
         for layer_id in range(self.num_layer):
             layer_weights = lora_adapter.layers[layer_id].weights
+            temp_A_buffer: Dict[str, torch.Tensor] = {}
+            temp_B_buffer: Dict[str, torch.Tensor] = {}
             for name, weights in layer_weights.items():
                 if "lora_A" in name:
                     lora_weight_name = get_weight_name(
                         name, self.lora_weight_names, LoRAType.LORA_A
                     )
-                    self.A_buffer[lora_weight_name][layer_id][buffer_id].copy_(
-                        weights
-                    )
+                    temp_A_buffer[lora_weight_name] = weights
                 else:
                     lora_weight_name = get_weight_name(
                         name, self.lora_weight_names, LoRAType.LORA_B
                     )
+                    temp_B_buffer[lora_weight_name] = weights
+
+            if self.tp_size > 1:
+                cur_layer_modules = self.lora_modules[layer_id]
+                for module_name, module in cur_layer_modules:
+                    if "qkv_proj" in module_name:
+                        temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights(
+                            temp_A_buffer["qkv_proj"], self.tp_rank
+                        )
+                        temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"] = (
+                            module.slice_lora_b_weights(
+                                [temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"]],
+                                self.tp_rank,
+                            )
+                        )
+                    else:
+                        weight_name = get_weight_name(
+                            module_name, self.lora_weight_names, LoRAType.LORA_A
+                        )
+                        temp_A_buffer[weight_name] = module.slice_lora_a_weights(
+                            temp_A_buffer[weight_name], self.tp_rank
+                        )
+                        temp_B_buffer[weight_name] = module.slice_lora_b_weights(
+                            temp_B_buffer[weight_name], self.tp_rank
+                        )
+
+            for name, weights in temp_A_buffer.items():
+                self.A_buffer[name][layer_id][buffer_id].copy_(weights)
+
+            for name, weights in temp_B_buffer.items():
+                c = get_stacked_multiply(name)
+                if c > 1:
+                    for stacked_id in range(c):
+                        self.B_buffer[name][layer_id][stacked_id][buffer_id].copy_(
+                            weights[stacked_id]
+                        )
+                else:
+                    self.B_buffer[name][layer_id][0][buffer_id].copy_(weights)

     def get_tensor(
         self, weight_name: str, layer_id: int, lora_type: LoRAType
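
The new `get_lora_A_shape`/`get_lora_B_shape` helpers encode one sharding rule: under tensor parallelism, row-parallel linears (`o_proj`, `down_proj`) shard the input dimension of the LoRA-A buffer, while every other linear shards the output dimension of the LoRA-B buffer. A self-contained restatement of that rule, with purely hypothetical sizes and no sglang imports:

```python
# Toy illustration (not sglang code) of the buffer shapes computed above.
ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"]


def divide(numerator: int, denominator: int) -> int:
    assert numerator % denominator == 0
    return numerator // denominator


def lora_a_shape(name, hidden_in, max_loras, max_rank, stacked, tp_size):
    # Row-parallel layers split their input dim across ranks, so lora_A shrinks.
    if tp_size > 1 and name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
        hidden_in = divide(hidden_in, tp_size)
    return (max_loras, max_rank * stacked, hidden_in)


def lora_b_shape(name, hidden_out, max_loras, max_rank, stacked, tp_size):
    # Every other (column-parallel) layer splits its output dim, so lora_B shrinks.
    if tp_size > 1 and name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
        hidden_out = divide(hidden_out, tp_size)
    return (stacked, max_loras, hidden_out, max_rank)


# e.g. an o_proj adapter with hidden size 4096, rank 16, 8 slots, tp_size 2:
print(lora_a_shape("o_proj", 4096, 8, 16, 1, 2))  # (8, 16, 2048)
print(lora_b_shape("o_proj", 4096, 8, 16, 1, 2))  # (1, 8, 4096, 16)
```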
sglang/srt/lora/utils.py
CHANGED
@@ -133,9 +133,20 @@ def get_weight_name(
     target_name is name of a given module,
     lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above)
     If there is a weight name in lora_weight_names that can match target_name, return this name
-    Else
+    Else raise ValueError.
     """
     idx = 0 if lora_type == LoRAType.LORA_A else 1
     for weight_name_pair in lora_weight_names:
         if weight_name_pair[idx] in target_name:
             return weight_name_pair[idx]
+    raise ValueError(
+        f"Cannot find weight name for {target_name} in {lora_weight_names}"
+    )
+
+
+# TODO: [PR #4274] For future use to simplify the mapping between HF module names and customized module names.
+VOCAB_PARALLELISM_EMBEDDING_NAMES = ["embeddings"]
+COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_proj", "up_proj"]
+MERGED_COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_up_proj"]
+QKV_PARALLELISM_LINEAR_LORA_NAMES = ["qkv_proj"]
+ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"]
sglang/srt/managers/cache_controller.py
CHANGED
@@ -22,10 +22,7 @@ from typing import List, Optional

 import torch

-from sglang.srt.mem_cache.memory_pool import
-    MHATokenToKVPoolHost,
-    TokenToKVPoolAllocator,
-)
+from sglang.srt.mem_cache.memory_pool import HostKVCache, TokenToKVPoolAllocator

 logger = logging.getLogger(__name__)

@@ -151,7 +148,7 @@ class HiCacheController:
     def __init__(
         self,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
-        mem_pool_host:
+        mem_pool_host: HostKVCache,
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
     ):
@@ -248,6 +245,8 @@ class HiCacheController:
         if device_indices is None:
             return None
         self.mem_pool_host.protect_load(host_indices)
+        # to ensure the device indices are ready before accessed by another CUDA stream
+        torch.cuda.current_stream().synchronize()
         self.load_queue.put(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
sglang/srt/managers/data_parallel_controller.py
CHANGED
@@ -54,7 +54,7 @@ class LoadBalanceMethod(Enum):
 class DataParallelController:
     """A controller that dispatches requests to multiple data parallel workers."""

-    def __init__(self, server_args, port_args) -> None:
+    def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None:
         # Parse args
         self.max_total_num_tokens = None
         self.server_args = server_args
@@ -82,10 +82,12 @@ class DataParallelController:
         self.scheduler_procs = []
         self.workers = [None] * server_args.dp_size

-        if
-            dp_port_args = self.launch_dp_schedulers(server_args, port_args)
-        else:
+        if server_args.enable_dp_attention:
             dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
+            self.control_message_step = server_args.tp_size
+        else:
+            dp_port_args = self.launch_dp_schedulers(server_args, port_args)
+            self.control_message_step = 1

         # Only node rank 0 runs the real data parallel controller that dispatches the requests.
         if server_args.node_rank == 0:
@@ -105,6 +107,7 @@ class DataParallelController:
         threads = []
         sockets = []
         dp_port_args = []
+        ready_events = []
         for dp_rank in range(server_args.dp_size):
             tmp_port_args = PortArgs.init_new(server_args)
             tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
@@ -115,10 +118,13 @@ class DataParallelController:
             # We hold it first so that the next dp worker gets a different port
             sockets.append(bind_port(tmp_port_args.nccl_port))

+            ready_event = threading.Event()
+            ready_events.append(ready_event)
+
             # Create a thread for each worker
             thread = threading.Thread(
-                target=self.
-                args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
+                target=self.launch_tensor_parallel_group_thread,
+                args=(server_args, tmp_port_args, base_gpu_id, dp_rank, ready_event),
             )
             threads.append(thread)
             base_gpu_id += server_args.tp_size * server_args.gpu_id_step
@@ -130,11 +136,27 @@ class DataParallelController:
         # Start all threads
         for thread in threads:
             thread.start()
-        for
+        for event in ready_events:
+            event.wait()

         return dp_port_args

+    def launch_tensor_parallel_group_thread(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        base_gpu_id: int,
+        dp_rank: int,
+        ready_event: threading.Event,
+    ):
+        self.launch_tensor_parallel_group(server_args, port_args, base_gpu_id, dp_rank)
+        ready_event.set()
+
+        # This thread cannot be closed because otherwise the `kill_itself_when_parent_died`
+        # function in scheduler.py will kill the scheduler.
+        while True:
+            pass
+
     def launch_dp_attention_schedulers(self, server_args, port_args):
         self.launch_tensor_parallel_group(server_args, port_args, 0, None)
         dp_port_args = []
@@ -223,7 +245,7 @@ class DataParallelController:
                     self.dispatching(recv_req)
                 else:
                     # Send other control messages to first worker of tp group
-                    for worker in self.workers[:: self.
+                    for worker in self.workers[:: self.control_message_step]:
                         worker.send_pyobj(recv_req)

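
The `ready_events` additions give the controller an explicit readiness handshake: every per-rank launcher thread signals an `Event` once its tensor-parallel group is up, and the controller blocks until all events are set. The pattern in isolation (plain Python, names and timings illustrative):

```python
import threading
import time


def launch_group(dp_rank: int, ready_event: threading.Event) -> None:
    # Stand-in for launch_tensor_parallel_group(); pretend setup takes a while.
    time.sleep(0.05 * (dp_rank + 1))
    ready_event.set()
    # The real thread then parks forever so the child schedulers stay alive.


ready_events = []
for dp_rank in range(4):
    event = threading.Event()
    ready_events.append(event)
    threading.Thread(target=launch_group, args=(dp_rank, event), daemon=True).start()

# Block until every data parallel worker has reported readiness.
for event in ready_events:
    event.wait()
print("all tensor parallel groups are ready")
```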
sglang/srt/managers/expert_distribution.py
ADDED
@@ -0,0 +1,81 @@
+import json
+import logging
+import time
+from collections import defaultdict
+from typing import Dict, List, Tuple
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+# global expert distribution recording
+class ExpertDistributionRecorder:
+    # This class is a singleton class
+    def __new__(cls):
+        if not hasattr(cls, "instance"):
+            cls.instance = super(ExpertDistributionRecorder, cls).__new__(cls)
+        return cls.instance
+
+    def __init__(self):
+        # the length of the dictionary is the number of layers
+        # the length of the list is the number of tokens
+        # the length of the tuple is topk's k value
+        self._expert_distribution_record: Dict[int, List[Tuple[int]]] = defaultdict(
+            list
+        )
+        self._record = False
+        self._current_layer_id = "UNKNOWN"
+
+    def set_current_layer(self, layer_idx):
+        self._current_layer_id = layer_idx
+
+    def record_new_token(self, topk_ids):
+        if not self._record:
+            return
+        topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist()
+        torch.cuda.synchronize()
+        for i in topk_ids_list:
+            self._expert_distribution_record[self._current_layer_id].append(tuple(i))
+
+    def reset(self):
+        """Reset the expert distribution recorder."""
+        logger.info("Resetting expert distribution record...")
+        self._record = False
+        self._expert_distribution_record.clear()
+        self._current_layer_id = "UNKNOWN"
+
+    def start_record(self):
+        """Start recording the expert distribution. Reset the recorder and set the recording flag to True."""
+        if self._record == True:
+            logger.warning(
+                "SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?"
+            )
+        self.reset()
+        self._record = True
+
+    def stop_record(self):
+        """Stop recording the expert distribution. Set the recording flag to False."""
+        if self._record == False:
+            logger.warning(
+                "SGLang server has not been recording expert ids. Did you forget to start recording by sending request to the `/start_expert_distribution_record` endpoint?"
+            )
+        self._record = False
+
+    def dump_record(self):
+        """Dump the expert distribution record to a file. Reset the recorder after dumping."""
+        results = {}
+        for layer_idx, layer_record in self._expert_distribution_record.items():
+            results[layer_idx] = defaultdict(int)
+            for token_record in layer_record:
+                for expert_idx in token_record:
+                    results[layer_idx][expert_idx] += 1
+        with open(
+            f"expert_distribution_rank{torch.distributed.get_rank()}_timestamp{time.time()}.csv",
+            "w",
+        ) as fd:
+            fd.write("layer_id,expert_id,count\n")
+            for layer_idx, layer_results in results.items():
+                for expert_idx, count in layer_results.items():
+                    fd.write(f"{layer_idx},{expert_idx},{count}\n")
+        self.reset()
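
A hypothetical driver for the new recorder outside the server (in the server it is toggled through the expert-distribution HTTP endpoints named in the warnings above). `dump_record` names its CSV by `torch.distributed` rank, so this sketch initializes a one-process gloo group; it also assumes a CUDA-capable host because `record_new_token` calls `torch.cuda.synchronize()`:

```python
import torch
import torch.distributed as dist
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder

# dump_record() looks up the process rank, so create a trivial 1-process group.
if not dist.is_initialized():
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
    )

recorder = ExpertDistributionRecorder()  # singleton: repeated calls return the same object
recorder.start_record()
recorder.set_current_layer(0)
topk_ids = torch.tensor([[1, 5], [3, 5]], device="cuda")  # (num_tokens, k) router outputs
recorder.record_new_token(topk_ids)
recorder.stop_record()
recorder.dump_record()  # writes expert_distribution_rank0_timestamp<...>.csv and resets
```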
sglang/srt/managers/io_struct.py
CHANGED
@@ -45,6 +45,8 @@ class GenerateReqInput:
     # The image input. It can be a file name, a url, or base64 encoded string.
     # See also python/sglang/srt/utils.py:load_image.
     image_data: Optional[Union[List[str], str]] = None
+    # The audio input. Like image data, tt can be a file name, a url, or base64 encoded string.
+    audio_data: Optional[Union[List[str], str]] = None
     # The sampling_params. See descriptions below.
     sampling_params: Optional[Union[List[Dict], Dict]] = None
     # The request id.
@@ -103,6 +105,8 @@ class GenerateReqInput:
             self.batch_size = len(self.text)
             self.input_embeds = None
         elif self.input_ids is not None:
+            if len(self.input_ids) == 0:
+                raise ValueError("input_ids cannot be empty.")
             if isinstance(self.input_ids[0], int):
                 self.is_single = True
                 self.batch_size = 1
@@ -165,6 +169,13 @@ class GenerateReqInput:
         elif isinstance(self.image_data, list):
             pass

+        if self.audio_data is None:
+            self.audio_data = [None] * num
+        elif not isinstance(self.audio_data, list):
+            self.audio_data = [self.audio_data] * num
+        elif isinstance(self.audio_data, list):
+            pass
+
         if self.sampling_params is None:
             self.sampling_params = [{}] * num
         elif not isinstance(self.sampling_params, list):
@@ -229,6 +240,7 @@ class GenerateReqInput:
             text=self.text[i] if self.text is not None else None,
             input_ids=self.input_ids[i] if self.input_ids is not None else None,
             image_data=self.image_data[i],
+            audio_data=self.audio_data[i],
             sampling_params=self.sampling_params[i],
             rid=self.rid[i],
             return_logprob=self.return_logprob[i],
@@ -257,8 +269,8 @@ class TokenizedGenerateReqInput:
     input_text: str
     # The input token ids
     input_ids: List[int]
-    # The
+    # The multimodal inputs
+    mm_inputs: dict
     # The sampling parameters
     sampling_params: SamplingParams
     # Whether to return the logprobs
@@ -538,7 +550,8 @@ class UpdateWeightsFromDistributedReqOutput:

 @dataclass
 class UpdateWeightsFromTensorReqInput:
+    # List containing one serialized Dict[str, torch.Tensor] per TP worker
+    serialized_named_tensors: List[bytes]
     load_format: Optional[str]
     flush_cache: bool

@@ -645,6 +658,17 @@ class ProfileReqType(Enum):
     STOP_PROFILE = 2


+class ExpertDistributionReq(Enum):
+    START_RECORD = 1
+    STOP_RECORD = 2
+    DUMP_RECORD = 3
+
+
+@dataclass
+class ExpertDistributionReqOutput:
+    pass
+
+
 @dataclass
 class ProfileReq:
     type: ProfileReqType
@@ -723,3 +747,15 @@ class SeparateReasoningReqInput:
 class VertexGenerateReqInput:
     instances: List[dict]
     parameters: Optional[dict] = None
+
+
+@dataclass
+class RpcReqInput:
+    method: str
+    parameters: Optional[Dict] = None
+
+
+@dataclass
+class RpcReqOutput:
+    success: bool
+    message: str