sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +26 -4
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +676 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +49 -8
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/parallel_state.py +42 -8
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +78 -13
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +133 -55
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +434 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +41 -19
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +60 -20
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +80 -53
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +25 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -19
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +78 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +87 -33
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +67 -30
- sglang/srt/lora/mem_pool.py +117 -52
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +18 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +43 -5
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/clip.py +63 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -30
- sglang/srt/managers/scheduler.py +290 -31
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -24
- sglang/srt/managers/tp_worker.py +4 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +36 -21
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +75 -8
- sglang/srt/model_loader/loader.py +171 -3
- sglang/srt/model_loader/weight_utils.py +51 -3
- sglang/srt/models/clip.py +563 -0
- sglang/srt/models/deepseek_janus_pro.py +31 -88
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +329 -73
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +694 -0
- sglang/srt/models/gemma3_mm.py +468 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +201 -104
- sglang/srt/openai_api/protocol.py +33 -7
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +114 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +140 -54
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +215 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +29 -2
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +56 -5
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/lora/mem_pool.py
CHANGED
@@ -2,9 +2,12 @@ from typing import Dict, List, Optional, Set, Tuple
 
 import torch
 
+from sglang.srt.distributed import divide
 from sglang.srt.hf_transformers_utils import AutoConfig
+from sglang.srt.lora.layers import BaseLayerWithLoRA
 from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.utils import (
+    ROW_PARALLELISM_LINEAR_LORA_NAMES,
     LoRAType,
     get_hidden_dim,
     get_stacked_multiply,
@@ -21,6 +24,9 @@ class LoRAMemoryPool:
         max_loras_per_batch: int,
         max_lora_dim: int,
         dtype: torch.dtype,
+        tp_size: int,
+        tp_rank: int,
+        lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]],
     ):
 
         self.base_hf_config: AutoConfig = base_hf_config
@@ -28,6 +34,9 @@ class LoRAMemoryPool:
         self.max_loras_per_batch: int = max_loras_per_batch
         self.max_lora_dim: int = max_lora_dim
         self.dtype: torch.dtype = dtype
+        self.tp_size: int = tp_size
+        self.tp_rank: int = tp_rank
+        self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = lora_modules
 
         # Both A_buffer and B_buffer maps lora weight names to its buffer space.
         # A_buffer contains num_layer number of row-major tensors with shape
@@ -45,6 +54,41 @@ class LoRAMemoryPool:
         # Here we don't initalize to None since None is a valid uid
         self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
 
+    def get_lora_A_shape(
+        self, module_name: str, base_model: torch.nn.Module
+    ) -> Tuple[int]:
+        """
+        Given a module_name (might be a stacked name), return the hidden dims of modules's input and output.
+        """
+        input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model)
+        c = get_stacked_multiply(module_name)
+        if self.tp_size > 1:
+            if module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
+                input_dim = divide(input_dim, self.tp_size)
+        return (
+            self.max_loras_per_batch,
+            self.max_lora_dim * c,
+            input_dim,
+        )
+
+    def get_lora_B_shape(
+        self, module_name: str, base_model: torch.nn.Module
+    ) -> Tuple[int]:
+        """
+        Given a module_name (might be a stacked name), return the hidden dims of modules's input and output.
+        """
+        _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
+        c = get_stacked_multiply(module_name)
+        if self.tp_size > 1:
+            if module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
+                output_dim = divide(output_dim, self.tp_size)
+        return (
+            c,
+            self.max_loras_per_batch,
+            output_dim,
+            self.max_lora_dim,
+        )
+
     def init_buffers(
         self,
         lora_weight_names: Set[Tuple[str]],
@@ -54,42 +98,31 @@ class LoRAMemoryPool:
         # lora_weight_names is a set of name pairs indicating each pair of lora modules to load
         # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
         self.lora_weight_names: Set[Tuple[str]] = lora_weight_names
-
-        for
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            self.
-
-                    (
-                        c,  # stacked lora_b modules might need separation
-                        self.max_loras_per_batch,
-                        output_dim,
-                        self.max_lora_dim,
-                    ),
-                    dtype=self.dtype,
-                    device="cuda",
-                )
-                for i in range(self.num_layer)
-            ]
+        device = next(base_model.parameters()).device
+        lora_module_A_names = set([name[0] for name in lora_weight_names])
+        lora_module_B_names = set([name[1] for name in lora_weight_names])
+        # Init A tensor, column_major=False
+        for module_A in lora_module_A_names:
+            lora_A_shape = self.get_lora_A_shape(module_A, base_model)
+            self.A_buffer[module_A] = [
+                torch.empty(
+                    lora_A_shape,
+                    dtype=self.dtype,
+                    device=device,
+                )
+                for i in range(self.num_layer)
+            ]
+        # Init B tensor, column_major=True
+        for module_B in lora_module_B_names:
+            lora_B_shape = self.get_lora_B_shape(module_B, base_model)
+            self.B_buffer[module_B] = [
+                torch.empty(
+                    lora_B_shape,
+                    dtype=self.dtype,
+                    device=device,
+                )
+                for _ in range(self.num_layer)
+            ]
 
     def prepare_lora_batch(
         self,
@@ -130,36 +163,68 @@ class LoRAMemoryPool:
         if uid is None:
             for i in range(self.num_layer):
                 for k in self.A_buffer.keys():
-                    self.A_buffer[k][i][buffer_id]
+                    self.A_buffer[k][i][buffer_id] = 0
             return
 
         assert lora_adapter is not None
+        lora_rank = lora_adapter.config.hf_config["r"]
        for layer_id in range(self.num_layer):
            layer_weights = lora_adapter.layers[layer_id].weights
+            temp_A_buffer: Dict[str, torch.Tensor] = {}
+            temp_B_buffer: Dict[str, torch.Tensor] = {}
            for name, weights in layer_weights.items():
                if "lora_A" in name:
                    lora_weight_name = get_weight_name(
                        name, self.lora_weight_names, LoRAType.LORA_A
                    )
-
-                    self.A_buffer[lora_weight_name][layer_id][buffer_id].copy_(
-                        weights
-                    )
+                    temp_A_buffer[lora_weight_name] = weights
                else:
                    lora_weight_name = get_weight_name(
                        name, self.lora_weight_names, LoRAType.LORA_B
                    )
-
-
-
-
-
-
-
-
-
-
-
+                    temp_B_buffer[lora_weight_name] = weights
+
+            if self.tp_size > 1:
+                cur_layer_modules = self.lora_modules[layer_id]
+                for module_name, module in cur_layer_modules:
+                    if "qkv_proj" in module_name:
+                        temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights(
+                            temp_A_buffer["qkv_proj"], self.tp_rank
+                        )
+                        temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"] = (
+                            module.slice_lora_b_weights(
+                                [temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"]],
+                                self.tp_rank,
+                            )
+                        )
+                    else:
+                        weight_name = get_weight_name(
+                            module_name, self.lora_weight_names, LoRAType.LORA_A
+                        )
+                        temp_A_buffer[weight_name] = module.slice_lora_a_weights(
+                            temp_A_buffer[weight_name], self.tp_rank
+                        )
+                        temp_B_buffer[weight_name] = module.slice_lora_b_weights(
+                            temp_B_buffer[weight_name], self.tp_rank
+                        )
+
+            for name, weights in temp_A_buffer.items():
+                c = get_stacked_multiply(name)
+                self.A_buffer[name][layer_id][buffer_id][: lora_rank * c, :].copy_(
+                    weights
+                )
+
+            for name, weights in temp_B_buffer.items():
+                c = get_stacked_multiply(name)
+                if c > 1:
+                    for stacked_id in range(c):
+                        self.B_buffer[name][layer_id][stacked_id][buffer_id][
+                            :, :lora_rank
+                        ].copy_(weights[stacked_id])
+                else:
+                    self.B_buffer[name][layer_id][0][buffer_id][:, :lora_rank].copy_(
+                        weights
+                    )
 
     def get_tensor(
         self, weight_name: str, layer_id: int, lora_type: LoRAType
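The two new shape helpers above decide how the pooled LoRA buffers are sized under tensor parallelism: row-parallel linears (o_proj, down_proj) shard the lora_A input dimension across ranks, while the other modules shard the lora_B output dimension. A minimal standalone sketch of that sizing logic, with illustrative constants rather than sglang's real configuration, is shown below.

# Minimal sketch (not the sglang API) of how the new shape helpers size the
# pooled LoRA buffers under tensor parallelism; all constants are illustrative.
ROW_PARALLEL = {"o_proj", "down_proj"}   # lora_A input dim is sharded for these
MAX_LORAS, MAX_RANK, TP_SIZE = 8, 16, 4

def lora_a_shape(module: str, input_dim: int, stack: int = 1):
    # Row-parallel linears split their input dimension across TP ranks,
    # so each rank's A buffer only covers its slice of input_dim.
    if TP_SIZE > 1 and module in ROW_PARALLEL:
        input_dim //= TP_SIZE
    return (MAX_LORAS, MAX_RANK * stack, input_dim)

def lora_b_shape(module: str, output_dim: int, stack: int = 1):
    # Column-parallel linears split their output dimension instead,
    # so each rank's B buffer only covers its slice of output_dim.
    if TP_SIZE > 1 and module not in ROW_PARALLEL:
        output_dim //= TP_SIZE
    return (stack, MAX_LORAS, output_dim, MAX_RANK)

print(lora_a_shape("o_proj", 4096))       # (8, 16, 1024)
print(lora_b_shape("qkv_proj", 4096, 3))  # (3, 8, 1024, 16), hypothetical stacked qkv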
sglang/srt/lora/triton_ops/gate_up_lora_b.py
CHANGED
@@ -22,17 +22,18 @@ def _gate_up_lora_b_kernel(
     w_stride_2,
     output_stride_0,
     output_stride_1,
-    # Information on sequence lengths and weight id
+    # Information on sequence lengths,ranks and weight id
     seg_lens,
     seg_indptr,
     weight_indices,
+    lora_ranks,
     # Meta parameters
     BLOCK_S: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
     # For fused output scaling and adding
     fuse_scaling_add,
-
+    scalings,
 ):
     # This kernel packs 2 sgemms (gate/up) into a single kernel.
 
@@ -51,6 +52,11 @@ def _gate_up_lora_b_kernel(
     w_index = tl.load(weight_indices + batch_id)
     seg_start = tl.load(seg_indptr + batch_id)
     n_start = gate_up_id * output_dim  # offset on output dim
+    rank = tl.load(lora_ranks + w_index)
+    scaling = tl.load(scalings + w_index)
+
+    # Adjust K (rank) according to the specific LoRA adapter
+    K = tl.minimum(K, rank)
 
     # The tile in output matrix will have (pid_s, pid_n) as id
     num_pid_n = tl.cdiv(output_dim, BLOCK_N)
@@ -109,7 +115,6 @@ def gate_up_lora_b_fwd(
     batch_info: LoRABatchInfo,
     output_dim: int,
     base_output: torch.Tensor = None,
-    scaling: float = 1.0,
 ) -> torch.Tensor:
 
     # x: (s, 2 * r)
@@ -160,11 +165,12 @@ def gate_up_lora_b_fwd(
         batch_info.seg_lens,
         batch_info.seg_indptr,
         batch_info.weight_indices,
+        batch_info.lora_ranks,
         BLOCK_S,
         BLOCK_OUT,
         BLOCK_R,
         fuse_scaling_add,
-
+        batch_info.scalings,
     )
 
     return output
sglang/srt/lora/triton_ops/qkv_lora_b.py
CHANGED
@@ -26,6 +26,7 @@ def _qkv_lora_b_kernel(
     seg_lens,
     seg_indptr,
     weight_indices,
+    lora_ranks,
     # Offsets of q/k/v slice on output dimension
     n_offs,
     # Meta parameters
@@ -34,7 +35,7 @@ def _qkv_lora_b_kernel(
     BLOCK_K: tl.constexpr,
     # For fused output scaling and adding
     fuse_scaling_add,
-
+    scalings,
 ):
     # This kernel packs 3 sgemms (q/k/v) into a single kernel.
 
@@ -54,6 +55,10 @@ def _qkv_lora_b_kernel(
     seg_start = tl.load(seg_indptr + batch_id)
     n_start = tl.load(n_offs + qkv_id)
     n_size = tl.load(n_offs + qkv_id + 1) - n_start
+    rank = tl.load(lora_ranks + w_index)
+    scaling = tl.load(scalings + w_index)
+    # Adjust K (rank) according to the specific LoRA adapter
+    K = tl.minimum(K, rank)
 
     # The tile in output matrix will have (pid_s, pid_n) as id
     num_pid_n = tl.cdiv(max_qkv_out_dim, BLOCK_N)
@@ -112,7 +117,6 @@ def qkv_lora_b_fwd(
     output_offset: torch.Tensor,
     max_qkv_out_dim: int,
     base_output: torch.Tensor = None,
-    scaling: float = 1.0,
 ) -> torch.Tensor:
 
     # x: (s, 3 * r)
@@ -171,12 +175,13 @@ def qkv_lora_b_fwd(
         batch_info.seg_lens,
         batch_info.seg_indptr,
         batch_info.weight_indices,
+        batch_info.lora_ranks,
         output_offset,
         BLOCK_S,
         BLOCK_OUT,
         BLOCK_R,
         fuse_scaling_add,
-
+        batch_info.scalings,
     )
 
     return output
sglang/srt/lora/triton_ops/sgemm_lora_a.py
CHANGED
@@ -12,8 +12,9 @@ def _sgemm_lora_a_kernel(
     weights,
     output,
     # Matrix dimensions
-    N,  # r
+    N,  # stack_num * r
     K,  # input_dim
+    stack_num,
     # Strides
     x_stride_0,
     x_stride_1,
@@ -22,10 +23,11 @@ def _sgemm_lora_a_kernel(
     w_stride_2,
     output_stride_0,
     output_stride_1,
-    # Information on sequence lengths and weight id
+    # Information on sequence lengths,ranks and weight id
     seg_lens,
     seg_indptr,
     weight_indices,
+    lora_ranks,
     # Meta parameters
     BLOCK_S: tl.constexpr,
     BLOCK_N: tl.constexpr,
@@ -43,6 +45,9 @@ def _sgemm_lora_a_kernel(
     seg_len = tl.load(seg_lens + batch_id)
     w_index = tl.load(weight_indices + batch_id)
     seg_start = tl.load(seg_indptr + batch_id)
+    rank = tl.load(lora_ranks + w_index)
+    # Adjust N (stack_num * max_rank) according to the specific LoRA adapter
+    N = tl.minimum(N, rank * stack_num)
 
     # The tile in output matrix will have (pid_s, pid_n) as id
     num_pid_n = tl.cdiv(N, BLOCK_N)
@@ -91,11 +96,15 @@ def _sgemm_lora_a_kernel(
 
 
 def sgemm_lora_a_fwd(
-    x: torch.Tensor,
+    x: torch.Tensor,
+    weights: torch.Tensor,
+    batch_info: LoRABatchInfo,
+    stack_num: int = 1,
 ) -> torch.Tensor:
     # x: (s, input_dim)
-    # weights: (num_lora, r, input_dim)
-    # output: (s, r)
+    # weights: (num_lora, stack_num * r, input_dim)
+    # output: (s, stack_num * r)
+    # stack_num: run_qkv_lora: 3, run_gate_up_lora: 2
     # when called by run_qkv_lora, the weights.shape[-2] will be 3 * r
     # input_dim is much larger than r
 
@@ -126,6 +135,7 @@ def sgemm_lora_a_fwd(
         output,
         R,
         K,
+        stack_num,
         x.stride(0),
         x.stride(1),
         weights.stride(0),
@@ -136,6 +146,7 @@ def sgemm_lora_a_fwd(
         batch_info.seg_lens,
         batch_info.seg_indptr,
         batch_info.weight_indices,
+        batch_info.lora_ranks,
         BLOCK_S,
         BLOCK_R,
         BLOCK_K,
sglang/srt/lora/triton_ops/sgemm_lora_b.py
CHANGED
@@ -26,13 +26,14 @@ def _sgemm_lora_b_kernel(
     seg_lens,
     seg_indptr,
     weight_indices,
+    lora_ranks,
     # Meta parameters
     BLOCK_S: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
     # For fused output scaling and adding
     fuse_scaling_add,
-
+    scalings,
 ):
     # x: (s, K), s is the sum of sequence lengths
     # weights: (num_lora, N, K)
@@ -45,6 +46,10 @@ def _sgemm_lora_b_kernel(
     seg_len = tl.load(seg_lens + batch_id)
     w_index = tl.load(weight_indices + batch_id)
     seg_start = tl.load(seg_indptr + batch_id)
+    rank = tl.load(lora_ranks + w_index)
+    scaling = tl.load(scalings + w_index)
+    # Adjust K (rank) according to the specific LoRA adapter
+    K = tl.minimum(K, rank)
 
     # The tile in output matrix will have (pid_s, pid_n) as id
     num_pid_n = tl.cdiv(N, BLOCK_N)
@@ -100,12 +105,11 @@ def sgemm_lora_b_fwd(
     weights: torch.Tensor,
     batch_info: LoRABatchInfo,
     base_output: torch.Tensor = None,
-    scaling: float = 1.0,
 ) -> torch.Tensor:
-    # x: (s,
-    # weights: (num_lora, output_dim,
+    # x: (s, max_r)
+    # weights: (num_lora, output_dim, max_r)
     # output: (s, output_dim)
-    # output_dim is much larger than
+    # output_dim is much larger than max_r
 
     assert x.is_contiguous()
     assert weights.is_contiguous()
@@ -150,10 +154,11 @@ def sgemm_lora_b_fwd(
         batch_info.seg_lens,
         batch_info.seg_indptr,
         batch_info.weight_indices,
+        batch_info.lora_ranks,
         BLOCK_S,
         BLOCK_N,
         BLOCK_R,
         fuse_scaling_add,
-
+        batch_info.scalings,
     )
     return output
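Across all four Triton kernels in this release, the per-call scalar scaling argument is replaced by per-adapter lora_ranks and scalings tensors indexed by w_index, and the reduction dimension is clamped with tl.minimum(K, rank), so adapters with different ranks can share one buffer padded to the maximum rank. A rough plain-PyTorch illustration of the same idea follows (not the Triton kernels themselves; shapes and values are made up).

# Illustrative sketch in plain PyTorch of the per-adapter rank/scaling lookup
# that the kernels now do; buffers are padded to max_rank, each adapter only
# reduces over its own rank.
import torch

max_rank, in_dim, out_dim = 16, 64, 32
lora_ranks = torch.tensor([16, 8])      # rank of each loaded adapter
scalings = torch.tensor([1.0, 2.0])     # per-adapter scaling factor
A = torch.randn(2, max_rank, in_dim)    # padded A buffer (num_lora, max_rank, in_dim)
B = torch.randn(2, out_dim, max_rank)   # padded B buffer (num_lora, out_dim, max_rank)

def lora_apply(x: torch.Tensor, adapter: int) -> torch.Tensor:
    r = int(lora_ranks[adapter])        # kernel analogue: K = tl.minimum(K, rank)
    s = float(scalings[adapter])        # kernel analogue: tl.load(scalings + w_index)
    return s * (x @ A[adapter, :r].T) @ B[adapter, :, :r].T

x = torch.randn(4, in_dim)
print(lora_apply(x, 1).shape)           # torch.Size([4, 32])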
sglang/srt/lora/utils.py
CHANGED
@@ -25,6 +25,12 @@ class LoRABatchInfo:
     # The index of lora adapter used by each sequence, in shape (bs,)
     weight_indices: torch.Tensor
 
+    # ranks of each lora adapter, in shape (lora_num,)
+    lora_ranks: torch.Tensor
+
+    # scaling of each lora adapter, in shape (lora_num,)
+    scalings: torch.Tensor
+
 
 class LoRAType(Enum):
     LORA_A = 0
@@ -133,9 +139,20 @@ def get_weight_name(
     target_name is name of a given module,
     lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above)
     If there is a weight name in lora_weight_names that can match target_name, return this name
-    Else
+    Else raise ValueError.
     """
     idx = 0 if lora_type == LoRAType.LORA_A else 1
     for weight_name_pair in lora_weight_names:
         if weight_name_pair[idx] in target_name:
             return weight_name_pair[idx]
+    raise ValueError(
+        f"Cannot find weight name for {target_name} in {lora_weight_names}"
+    )
+
+
+# TODO: [PR #4274] For future use to simplify the mapping between HF module names and customized module names.
+VOCAB_PARALLELISM_EMBEDDING_NAMES = ["embeddings"]
+COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_proj", "up_proj"]
+MERGED_COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_up_proj"]
+QKV_PARALLELISM_LINEAR_LORA_NAMES = ["qkv_proj"]
+ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"]
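For reference, get_weight_name now raises when a module cannot be matched against the stacked name pairs instead of falling through. A toy re-implementation (standalone, not an import of the sglang helper) shows the matching behavior.

# Standalone toy version of the matching logic shown in the diff above.
from enum import Enum

class LoRAType(Enum):
    LORA_A = 0
    LORA_B = 1

lora_weight_names = {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}

def get_weight_name(target_name, lora_weight_names, lora_type):
    # Index 0 of each pair names the stacked lora_A module, index 1 the lora_B module.
    idx = 0 if lora_type == LoRAType.LORA_A else 1
    for pair in lora_weight_names:
        if pair[idx] in target_name:
            return pair[idx]
    # An unmatched module is now a hard error rather than a silent miss.
    raise ValueError(f"Cannot find weight name for {target_name} in {lora_weight_names}")

print(get_weight_name("model.layers.0.self_attn.qkv_proj.lora_A", lora_weight_names, LoRAType.LORA_A))
# -> "qkv_proj"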
sglang/srt/managers/cache_controller.py
CHANGED
@@ -22,10 +22,7 @@ from typing import List, Optional
 
 import torch
 
-from sglang.srt.mem_cache.memory_pool import
-    MHATokenToKVPoolHost,
-    TokenToKVPoolAllocator,
-)
+from sglang.srt.mem_cache.memory_pool import HostKVCache, TokenToKVPoolAllocator
 
 logger = logging.getLogger(__name__)
 
@@ -151,7 +148,7 @@ class HiCacheController:
     def __init__(
         self,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
-        mem_pool_host:
+        mem_pool_host: HostKVCache,
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
     ):
sglang/srt/managers/data_parallel_controller.py
CHANGED
@@ -82,10 +82,12 @@ class DataParallelController:
         self.scheduler_procs = []
         self.workers = [None] * server_args.dp_size
 
-        if
-            dp_port_args = self.launch_dp_schedulers(server_args, port_args)
-        else:
+        if server_args.enable_dp_attention:
             dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
+            self.control_message_step = server_args.tp_size
+        else:
+            dp_port_args = self.launch_dp_schedulers(server_args, port_args)
+            self.control_message_step = 1
 
         # Only node rank 0 runs the real data parallel controller that dispatches the requests.
         if server_args.node_rank == 0:
@@ -105,6 +107,7 @@ class DataParallelController:
         threads = []
         sockets = []
         dp_port_args = []
+        ready_events = []
         for dp_rank in range(server_args.dp_size):
             tmp_port_args = PortArgs.init_new(server_args)
             tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
@@ -115,10 +118,13 @@ class DataParallelController:
             # We hold it first so that the next dp worker gets a different port
             sockets.append(bind_port(tmp_port_args.nccl_port))
 
+            ready_event = threading.Event()
+            ready_events.append(ready_event)
+
             # Create a thread for each worker
             thread = threading.Thread(
-                target=self.
-                args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
+                target=self.launch_tensor_parallel_group_thread,
+                args=(server_args, tmp_port_args, base_gpu_id, dp_rank, ready_event),
             )
             threads.append(thread)
             base_gpu_id += server_args.tp_size * server_args.gpu_id_step
@@ -130,11 +136,27 @@ class DataParallelController:
         # Start all threads
         for thread in threads:
             thread.start()
-        for
-
+        for event in ready_events:
+            event.wait()
 
         return dp_port_args
 
+    def launch_tensor_parallel_group_thread(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        base_gpu_id: int,
+        dp_rank: int,
+        ready_event: threading.Event,
+    ):
+        self.launch_tensor_parallel_group(server_args, port_args, base_gpu_id, dp_rank)
+        ready_event.set()
+
+        # This thread cannot be closed because otherwise the `kill_itself_when_parent_died`
+        # function in scheduler.py will kill the scheduler.
+        while True:
+            pass
+
     def launch_dp_attention_schedulers(self, server_args, port_args):
         self.launch_tensor_parallel_group(server_args, port_args, 0, None)
         dp_port_args = []
@@ -223,7 +245,7 @@ class DataParallelController:
                 self.dispatching(recv_req)
             else:
                 # Send other control messages to first worker of tp group
-                for worker in self.workers[:: self.
+                for worker in self.workers[:: self.control_message_step]:
                     worker.send_pyobj(recv_req)
 
 
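The controller change above launches one thread per data-parallel rank, has each thread signal a threading.Event once its tensor-parallel group is up, and then waits on the events instead of joining the threads, which now stay alive on purpose. A minimal generic sketch of that readiness handshake, with placeholder work standing in for launch_tensor_parallel_group, is shown below.

# Generic readiness-handshake sketch (placeholder work, not sglang's scheduler launch).
import threading
import time

def launch_group_thread(dp_rank: int, ready_event: threading.Event):
    time.sleep(0.1 * dp_rank)   # stand-in for the real tensor-parallel group launch
    ready_event.set()           # tell the controller this rank is up
    while True:                 # keep the thread alive, mirroring the diff
        time.sleep(60)

ready_events = []
for dp_rank in range(4):
    ev = threading.Event()
    ready_events.append(ev)
    threading.Thread(target=launch_group_thread, args=(dp_rank, ev), daemon=True).start()

for ev in ready_events:         # wait for every rank instead of joining
    ev.wait()
print("all dp ranks ready")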
sglang/srt/managers/expert_distribution.py
ADDED
@@ -0,0 +1,81 @@
+import json
+import logging
+import time
+from collections import defaultdict
+from typing import Dict, List, Tuple
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+# global expert distribution recording
+class ExpertDistributionRecorder:
+    # This class is a singleton class
+    def __new__(cls):
+        if not hasattr(cls, "instance"):
+            cls.instance = super(ExpertDistributionRecorder, cls).__new__(cls)
+        return cls.instance
+
+    def __init__(self):
+        # the length of the dictionary is the number of layers
+        # the length of the list is the number of tokens
+        # the length of the tuple is topk's k value
+        self._expert_distribution_record: Dict[int, List[Tuple[int]]] = defaultdict(
+            list
+        )
+        self._record = False
+        self._current_layer_id = "UNKNOWN"
+
+    def set_current_layer(self, layer_idx):
+        self._current_layer_id = layer_idx
+
+    def record_new_token(self, topk_ids):
+        if not self._record:
+            return
+        topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist()
+        torch.cuda.synchronize()
+        for i in topk_ids_list:
+            self._expert_distribution_record[self._current_layer_id].append(tuple(i))
+
+    def reset(self):
+        """Reset the expert distribution recorder."""
+        logger.info("Resetting expert distribution record...")
+        self._record = False
+        self._expert_distribution_record.clear()
+        self._current_layer_id = "UNKNOWN"
+
+    def start_record(self):
+        """Start recording the expert distribution. Reset the recorder and set the recording flag to True."""
+        if self._record == True:
+            logger.warning(
+                "SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?"
+            )
+        self.reset()
+        self._record = True
+
+    def stop_record(self):
+        """Stop recording the expert distribution. Set the recording flag to False."""
+        if self._record == False:
+            logger.warning(
+                "SGLang server has not been recording expert ids. Did you forget to start recording by sending request to the `/start_expert_distribution_record` endpoint?"
+            )
+        self._record = False
+
+    def dump_record(self):
+        """Dump the expert distribution record to a file. Reset the recorder after dumping."""
+        results = {}
+        for layer_idx, layer_record in self._expert_distribution_record.items():
+            results[layer_idx] = defaultdict(int)
+            for token_record in layer_record:
+                for expert_idx in token_record:
+                    results[layer_idx][expert_idx] += 1
+        with open(
+            f"expert_distribution_rank{torch.distributed.get_rank()}_timestamp{time.time()}.csv",
+            "w",
+        ) as fd:
+            fd.write("layer_id,expert_id,count\n")
+            for layer_idx, layer_results in results.items():
+                for expert_idx, count in layer_results.items():
+                    fd.write(f"{layer_idx},{expert_idx},{count}\n")
+        self.reset()