sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +10 -8
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +157 -21
- sglang/bench_serving.py +137 -59
- sglang/compile_deep_gemm.py +5 -5
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +40 -28
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +49 -44
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +129 -135
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +238 -122
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +10 -19
- sglang/srt/disaggregation/prefill.py +132 -47
- sglang/srt/disaggregation/utils.py +123 -6
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +44 -9
- sglang/srt/entrypoints/http_server.py +23 -6
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +64 -18
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +6 -4
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +61 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
- sglang/srt/layers/moe/ep_moe/layer.py +105 -51
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +67 -10
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +8 -3
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +77 -74
- sglang/srt/layers/quantization/fp8.py +92 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +20 -7
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +2 -4
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +19 -4
- sglang/srt/managers/mm_utils.py +294 -140
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +122 -42
- sglang/srt/managers/schedule_policy.py +1 -5
- sglang/srt/managers/scheduler.py +205 -138
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +232 -58
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +76 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +314 -39
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +29 -19
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +5 -1
- sglang/srt/model_executor/model_runner.py +163 -68
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +308 -351
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +15 -8
- sglang/srt/models/llava.py +258 -7
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +58 -20
- sglang/srt/openai_api/protocol.py +6 -8
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +162 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +138 -7
- sglang/srt/speculative/eagle_worker.py +69 -21
- sglang/srt/utils.py +74 -17
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +55 -14
- sglang/utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/qoq.py
ADDED
@@ -0,0 +1,244 @@
+from typing import Any, Callable, Dict, List, Optional
+
+import torch
+from torch.nn.parameter import Parameter
+
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.linear import LinearMethodBase
+from sglang.srt.layers.parameter import (
+    ChannelQuantScaleParameter,
+    GroupQuantScaleParameter,
+    ModelWeightParameter,
+)
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+if _is_cuda:
+    from sgl_kernel import qserve_w4a8_per_chn_gemm, qserve_w4a8_per_group_gemm
+
+
+QoQ_SUPPORTED_WEIGHT_BITS = [4]
+QoQ_SUPPORTED_GROUP_SIZES = [-1, 128]
+
+
+class QoQConfig(QuantizationConfig):
+    """Config class for QoQ Quantization.
+
+    - Weight: static, per-channel/group, asymmetric
+    - Activation: dynamic, per-token, symmetric
+
+    Reference: https://arxiv.org/abs/2405.04532
+        https://github.com/mit-han-lab/omniserve
+    """
+
+    def __init__(self, weight_bits: int, group_size: int) -> None:
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+
+        # Verify
+        if self.weight_bits not in QoQ_SUPPORTED_WEIGHT_BITS:
+            raise ValueError(
+                f"QoQ does not support weight_bits = {self.weight_bits}. "
+                f"Only weight_bits = {QoQ_SUPPORTED_WEIGHT_BITS} "
+                "are supported."
+            )
+        if self.group_size not in QoQ_SUPPORTED_GROUP_SIZES:
+            raise ValueError(
+                f"QoQ does not support group_size = {self.group_size}. "
+                f"Only group_sizes = {QoQ_SUPPORTED_GROUP_SIZES} "
+                "are supported."
+            )
+
+        # 4 bits packed into 8 bit datatype.
+        self.pack_factor = 8 // self.weight_bits
+
+    def __repr__(self) -> str:
+        return "QoQConfig(weight_bits={}, group_size={})".format(
+            self.weight_bits, self.group_size
+        )
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.float16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 80
+
+    @classmethod
+    def get_name(self) -> str:
+        return "qoq"
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        """List of filenames to search for in the model directory."""
+        return [
+            "quant_config.json",
+            "quantize_config.json",
+        ]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "QoQConfig":
+        weight_bits = cls.get_from_keys(config, ["wbits"])
+        group_size = cls.get_from_keys(config, ["group_size"])
+        return cls(weight_bits, group_size)
+
+    def get_quant_method(
+        self,
+        layer: torch.nn.Module,
+        prefix: str,
+    ) -> Optional["QuantizeMethodBase"]:
+        from sglang.srt.layers.linear import LinearBase
+
+        if isinstance(layer, LinearBase):
+            return QoQLinearMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class QoQLinearMethod(LinearMethodBase):
+    """Linear method for QoQ.
+
+    Args:
+        quant_config: The QoQ quantization config.
+    """
+
+    def __init__(self, quant_config: QoQConfig):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        # Validate output_size_per_partition
+        output_size_per_partition = sum(output_partition_sizes)
+        if output_size_per_partition % 32 != 0:
+            raise ValueError(
+                f"Weight output_size_per_partition = "
+                f"{output_size_per_partition} is not divisible by 32."
+            )
+
+        # Validate input_size_per_partition
+        if input_size_per_partition % self.quant_config.pack_factor != 0:
+            raise ValueError(
+                f"Weight input_size_per_partition = "
+                f"{input_size_per_partition} is not divisible by "
+                f"pack_factor = {self.quant_config.pack_factor}."
+            )
+        if (
+            self.quant_config.group_size != -1
+            and input_size_per_partition % self.quant_config.group_size != 0
+        ):
+            raise ValueError(
+                f"Weight input_size_per_partition = "
+                f"{input_size_per_partition} is not divisible by "
+                f"group_size = {self.quant_config.group_size}."
+            )
+
+        qweight = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition // self.quant_config.pack_factor,
+                dtype=torch.int8,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("qweight", qweight)
+
+        s1_scales = ChannelQuantScaleParameter(
+            data=torch.empty(output_size_per_partition, dtype=torch.float16),
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("s1_scales", s1_scales)
+
+        if self.quant_config.group_size == -1:
+            s1_szeros = ChannelQuantScaleParameter(
+                data=torch.empty(output_size_per_partition, dtype=torch.float16),
+                output_dim=0,
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("s1_szeros", s1_szeros)
+        else:
+            s2_scales = GroupQuantScaleParameter(
+                data=torch.empty(
+                    (
+                        input_size_per_partition // self.quant_config.group_size,
+                        output_size_per_partition,
+                    ),
+                    dtype=torch.int8,
+                ),
+                input_dim=0,
+                output_dim=1,
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("s2_scales", s2_scales)
+
+            s2_zeros = GroupQuantScaleParameter(
+                data=torch.empty(
+                    (
+                        input_size_per_partition // self.quant_config.group_size,
+                        output_size_per_partition,
+                    ),
+                    dtype=torch.int8,
+                ),
+                input_dim=0,
+                output_dim=1,
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("s2_zeros", s2_zeros)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
+        layer.s1_scales = Parameter(layer.s1_scales.data, requires_grad=False)
+        if self.quant_config.group_size == -1:
+            layer.s1_szeros = Parameter(layer.s1_szeros.data, requires_grad=False)
+        else:
+            layer.s2_scales = Parameter(layer.s2_scales.data, requires_grad=False)
+            layer.s2_zeros = Parameter(layer.s2_zeros.data, requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ):
+        assert x.dtype == torch.float16, "QoQ only supports float16 input now"
+        if self.quant_config.group_size == -1:
+            x_q, x_scale, x_sum = per_token_quant_int8(
+                x, scale_dtype=x.dtype, cal_sum=True
+            )
+            out = qserve_w4a8_per_chn_gemm(
+                x_q, layer.qweight, layer.s1_scales, x_scale, layer.s1_szeros, x_sum
+            )
+        else:
+            x_q, x_scale = per_token_quant_int8(x, scale_dtype=x.dtype)
+            out = qserve_w4a8_per_group_gemm(
+                x_q,
+                layer.qweight,
+                layer.s2_zeros,
+                layer.s2_scales,
+                layer.s1_scales,
+                x_scale,
+            )
+        if bias is not None:
+            out = out + bias
+        return out
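Note on the apply() path above: the per-channel branch (group_size == -1) passes the per-token activation sums because the 4-bit weights carry an asymmetric zero point, which can be folded in as an s1_szeros * x_sum correction. The pure-PyTorch sketch below illustrates one plausible dequantization identity behind that call; the variable names mirror the diff, but the exact scale/zero convention inside qserve_w4a8_per_chn_gemm is an assumption, not taken from the kernel source.

import torch

def w4a8_per_channel_reference(x_q, x_scale, x_sum, w_q, s1_scales, s1_szeros):
    # x_q:       [M, K] int8, per-token symmetric activations
    # x_scale:   [M]    per-token activation scales
    # x_sum:     [M]    per-token sums of x_q (compensates the asymmetric weight zero)
    # w_q:       [N, K] unpacked unsigned 4-bit weights
    # s1_scales: [N]    per-channel weight scales
    # s1_szeros: [N]    per-channel zero points pre-multiplied by s1_scales (assumed)
    acc = x_q.float() @ w_q.float().t()  # integer accumulation, done in float for clarity
    out = s1_scales[None, :] * acc - s1_szeros[None, :] * x_sum.float()[:, None]
    return (x_scale[:, None] * out).to(torch.float16)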
sglang/srt/layers/sampler.py
CHANGED
@@ -239,10 +239,6 @@ def top_p_normalize_probs_torch(
 
 
 def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
-    assert len(top_logprobs_nums) == logprobs.shape[0], (
-        len(top_logprobs_nums),
-        logprobs.shape[0],
-    )
     max_k = max(top_logprobs_nums)
     ret = logprobs.topk(max_k, dim=1)
     values = ret.values.tolist()
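For context, the unchanged lines above take a single topk over the whole batch with the largest requested k and then slice per request; the removed assert only restated that top_logprobs_nums must have one entry per batch row. A standalone sketch of that pattern (the tensor sizes and k values are illustrative, not from sglang):

import torch

logprobs = torch.log_softmax(torch.randn(3, 32000), dim=-1)  # batch of 3 requests
top_logprobs_nums = [1, 5, 3]                                # per-request k values

max_k = max(top_logprobs_nums)
ret = logprobs.topk(max_k, dim=1)
values, indices = ret.values.tolist(), ret.indices.tolist()

# Each request keeps only the first k entries of the shared top-max_k result.
per_request = [(values[i][:k], indices[i][:k]) for i, k in enumerate(top_logprobs_nums)]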
sglang/srt/layers/vocab_parallel_embedding.py
CHANGED
@@ -13,6 +13,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.parameter import BasevLLMParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -214,12 +215,14 @@ class VocabParallelEmbedding(torch.nn.Module):
         self,
         num_embeddings: int,
         embedding_dim: int,
+        *,
         params_dtype: Optional[torch.dtype] = None,
         org_num_embeddings: Optional[int] = None,
         padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         enable_tp: bool = True,
+        use_attn_tp_group: bool = False,
         use_presharded_weights: bool = False,
     ):
         super().__init__()
@@ -227,9 +230,14 @@ class VocabParallelEmbedding(torch.nn.Module):
 
         self.enable_tp = enable_tp
         if self.enable_tp:
-            tp_rank = get_tensor_model_parallel_rank()
-            self.tp_size = get_tensor_model_parallel_world_size()
+            if use_attn_tp_group:
+                tp_rank = get_attention_tp_rank()
+                self.tp_size = get_attention_tp_size()
+            else:
+                tp_rank = get_tensor_model_parallel_rank()
+                self.tp_size = get_tensor_model_parallel_world_size()
         else:
+            assert use_attn_tp_group is False
             tp_rank = 0
             self.tp_size = 1
 
@@ -519,22 +527,25 @@ class ParallelLMHead(VocabParallelEmbedding):
         self,
         num_embeddings: int,
         embedding_dim: int,
+        *,
         bias: bool = False,
         params_dtype: Optional[torch.dtype] = None,
         org_num_embeddings: Optional[int] = None,
         padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_attn_tp_group: bool = False,
         use_presharded_weights: bool = False,
     ):
         super().__init__(
             num_embeddings,
             embedding_dim,
-            params_dtype,
-            org_num_embeddings,
-            padding_size,
-            quant_config,
-            prefix,
+            params_dtype=params_dtype,
+            org_num_embeddings=org_num_embeddings,
+            padding_size=padding_size,
+            quant_config=quant_config,
+            prefix=prefix,
+            use_attn_tp_group=use_attn_tp_group,
             use_presharded_weights=use_presharded_weights,
         )
         self.quant_config = quant_config
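Because both constructors now take their optional arguments as keyword-only (the added *,) and forward them by name, positional call sites for anything after embedding_dim no longer work. A minimal illustrative call under the new signature; the concrete sizes and the choice to enable use_attn_tp_group are hypothetical, not taken from this diff:

import torch

from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead

lm_head = ParallelLMHead(
    151936,                      # num_embeddings: vocabulary size of the served model
    4096,                        # embedding_dim: hidden size
    params_dtype=torch.float16,
    prefix="lm_head",
    use_attn_tp_group=True,      # hypothetical: shard over the attention TP group instead of the full TP group
)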
sglang/srt/lora/lora_manager.py
CHANGED
@@ -100,7 +100,7 @@ class LoRAManager:
         self.configs[name] = LoRAConfig(path)
         self.hf_target_names.update(self.configs[name].target_modules)
 
-        # Target lora weight names for lora_a and lora_b modules
+        # Target lora weight names for lora_a and lora_b modules respectively.
         # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
         self.lora_weight_names: Set[Tuple[str]] = set(
             [get_stacked_name(module) for module in self.hf_target_names]
@@ -170,9 +170,7 @@ class LoRAManager:
             dim=0,
             out=self.cuda_graph_batch_info.seg_indptr[1 : bs + 1],
         )
-        self.cuda_graph_batch_info.max_len = int(
-            torch.max(self.cuda_graph_batch_info.seg_lens[:bs])
-        )
+        self.cuda_graph_batch_info.max_len = 1
 
         for i, lora_path in enumerate(forward_batch.lora_paths):
             self.cuda_graph_batch_info.weight_indices[i] = (
sglang/srt/lora/mem_pool.py
CHANGED
@@ -50,15 +50,15 @@ class LoRAMemoryPool:
         self.uid_to_buffer_id: Dict[Optional[str], int] = {}
 
         # Buffer idx -> lora uid in memory pool
-        # All uids are
-        # Here we don't
+        # All uids are initialized as empty strings for empty buffer slots
+        # Here we don't initialize to None since None is a valid uid
         self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
 
     def get_lora_A_shape(
         self, module_name: str, base_model: torch.nn.Module
     ) -> Tuple[int]:
         """
-        Given a module_name (might be a stacked name), return the hidden dims of modules'
+        Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
         input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model)
         c = get_stacked_multiply(module_name)
@@ -75,7 +75,7 @@ class LoRAMemoryPool:
         self, module_name: str, base_model: torch.nn.Module
     ) -> Tuple[int]:
         """
-        Given a module_name (might be a stacked name), return the hidden dims of modules'
+        Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
         _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
         c = get_stacked_multiply(module_name)
sglang/srt/lora/triton_ops/gate_up_lora_b.py
CHANGED
@@ -77,7 +77,7 @@ def _gate_up_lora_b_kernel(
         k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
     )
 
-    #
+    # Iterate to compute the block in output matrix
    partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        x_tile = tl.load(
sglang/srt/lora/triton_ops/qkv_lora_b.py
CHANGED
@@ -79,7 +79,7 @@ def _qkv_lora_b_kernel(
         k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
     )
 
-    #
+    # Iterate to compute the block in output matrix
    partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        x_tile = tl.load(
sglang/srt/lora/triton_ops/sgemm_lora_a.py
CHANGED
@@ -67,7 +67,7 @@ def _sgemm_lora_a_kernel(
         k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
     )
 
-    #
+    # Iterate to compute the block in output matrix
    partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        x_tile = tl.load(
sglang/srt/lora/triton_ops/sgemm_lora_b.py
CHANGED
@@ -69,7 +69,7 @@ def _sgemm_lora_b_kernel(
         k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
     )
 
-    #
+    # Iterate to compute the block in output matrix
    partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        x_tile = tl.load(
sglang/srt/lora/utils.py
CHANGED
@@ -79,7 +79,7 @@ def get_hidden_dim(
     module_name: str, config: AutoConfig, base_model: torch.nn.Module
 ) -> Tuple[int]:
     """
-    Given a module_name (might be a stacked name), return the hidden dims of modules'
+    Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
     """
 
     if hasattr(base_model, "get_hidden_dim"):
sglang/srt/managers/data_parallel_controller.py
CHANGED
@@ -17,13 +17,13 @@ import logging
 import multiprocessing as mp
 import signal
 import threading
+import time
 from enum import Enum, auto
 
 import psutil
 import setproctitle
 import zmq
 
-from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.managers.io_struct import (
     TokenizedEmbeddingReqInput,
@@ -158,7 +158,7 @@ class DataParallelController:
         # This thread cannot be closed because otherwise the `kill_itself_when_parent_died`
         # function in scheduler.py will kill the scheduler.
         while True:
-            pass
+            time.sleep(30 * 24 * 3600)
 
     def launch_dp_attention_schedulers(self, server_args, port_args):
         self.launch_tensor_parallel_group(server_args, port_args, 0, None)
@@ -210,7 +210,7 @@ class DataParallelController:
             )
             # compute zmq ports for this dp rank
             rank_port_args = PortArgs.init_new(server_args, dp_rank)
-            # Data parallelism
+            # Data parallelism reuses the tensor parallelism group,
             # so all dp ranks should use the same nccl port.
             rank_port_args.nccl_port = port_args.nccl_port
 