sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +1 -1
- sglang/lang/chat_template.py +29 -0
- sglang/srt/_custom_ops.py +19 -17
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/janus_pro.py +629 -0
- sglang/srt/configs/model_config.py +24 -14
- sglang/srt/conversation.py +80 -2
- sglang/srt/custom_op.py +64 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
- sglang/srt/distributed/parallel_state.py +10 -1
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/http_server.py +1 -1
- sglang/srt/function_call_parser.py +33 -2
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
- sglang/srt/layers/attention/triton_backend.py +1 -3
- sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
- sglang/srt/layers/attention/vision.py +43 -62
- sglang/srt/layers/dp_attention.py +30 -2
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/linear.py +1 -1
- sglang/srt/layers/logits_processor.py +1 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +25 -9
- sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
- sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/parameter.py +10 -0
- sglang/srt/layers/quantization/__init__.py +90 -68
- sglang/srt/layers/quantization/blockwise_int8.py +1 -2
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +174 -106
- sglang/srt/layers/quantization/fp8_kernel.py +210 -38
- sglang/srt/layers/quantization/fp8_utils.py +156 -15
- sglang/srt/layers/quantization/modelopt_quant.py +5 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
- sglang/srt/layers/quantization/w8a8_int8.py +152 -3
- sglang/srt/layers/rotary_embedding.py +5 -3
- sglang/srt/layers/sampler.py +29 -35
- sglang/srt/layers/vocab_parallel_embedding.py +0 -1
- sglang/srt/lora/backend/__init__.py +9 -12
- sglang/srt/managers/cache_controller.py +74 -8
- sglang/srt/managers/data_parallel_controller.py +1 -1
- sglang/srt/managers/image_processor.py +37 -631
- sglang/srt/managers/image_processors/base_image_processor.py +219 -0
- sglang/srt/managers/image_processors/janus_pro.py +79 -0
- sglang/srt/managers/image_processors/llava.py +152 -0
- sglang/srt/managers/image_processors/minicpmv.py +86 -0
- sglang/srt/managers/image_processors/mlama.py +60 -0
- sglang/srt/managers/image_processors/qwen_vl.py +161 -0
- sglang/srt/managers/io_struct.py +32 -15
- sglang/srt/managers/multi_modality_padding.py +134 -0
- sglang/srt/managers/schedule_batch.py +213 -118
- sglang/srt/managers/schedule_policy.py +40 -8
- sglang/srt/managers/scheduler.py +176 -683
- sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
- sglang/srt/managers/tokenizer_manager.py +6 -6
- sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
- sglang/srt/mem_cache/base_prefix_cache.py +6 -8
- sglang/srt/mem_cache/chunk_cache.py +12 -44
- sglang/srt/mem_cache/hiradix_cache.py +71 -34
- sglang/srt/mem_cache/memory_pool.py +81 -17
- sglang/srt/mem_cache/paged_allocator.py +283 -0
- sglang/srt/mem_cache/radix_cache.py +117 -36
- sglang/srt/model_executor/cuda_graph_runner.py +68 -20
- sglang/srt/model_executor/forward_batch_info.py +23 -10
- sglang/srt/model_executor/model_runner.py +63 -63
- sglang/srt/model_loader/loader.py +2 -1
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/deepseek_janus_pro.py +2127 -0
- sglang/srt/models/deepseek_nextn.py +23 -3
- sglang/srt/models/deepseek_v2.py +200 -191
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/minicpmv.py +28 -89
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/qwen2.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +25 -50
- sglang/srt/models/qwen2_vl.py +33 -49
- sglang/srt/openai_api/adapter.py +59 -35
- sglang/srt/openai_api/protocol.py +8 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
- sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
- sglang/srt/server_args.py +24 -16
- sglang/srt/speculative/eagle_worker.py +75 -39
- sglang/srt/utils.py +104 -9
- sglang/test/runners.py +104 -10
- sglang/test/test_block_fp8.py +106 -16
- sglang/test/test_custom_ops.py +88 -0
- sglang/test/test_utils.py +20 -4
- sglang/utils.py +0 -4
- sglang/version.py +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/w8a8_int8.py
CHANGED
```diff
@@ -1,8 +1,8 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
 
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import is_cuda_available, set_weight_attrs
 
 is_cuda = is_cuda_available()
 if is_cuda:
@@ -10,6 +10,7 @@ if is_cuda:
 
 from torch.nn.parameter import Parameter
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.linear import LinearMethodBase
 from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -55,9 +56,12 @@ class W8A8Int8Config(QuantizationConfig):
         prefix: str,
     ) -> Optional["QuantizeMethodBase"]:
         from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
         if isinstance(layer, LinearBase):
             return W8A8Int8LinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return W8A8Int8MoEMethod(self)
         return None
 
     def get_scaled_act_names(self) -> List[str]:
@@ -81,7 +85,7 @@ class W8A8Int8LinearMethod(LinearMethodBase):
         input_size: int,
         output_size: int,
         params_dtype: torch.dtype,
-        **extra_weight_attrs
+        **extra_weight_attrs,
     ):
 
         weight_loader = extra_weight_attrs.get("weight_loader")
@@ -115,3 +119,148 @@ class W8A8Int8LinearMethod(LinearMethodBase):
         return int8_scaled_mm(
             x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
         )
+
+
+class W8A8Int8MoEMethod:
+    """MoE method for INT8.
+    Supports loading INT8 checkpoints with static weight scale and
+    dynamic/static activation scale.
+    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
+    activation scaling. The weight scaling factor will be initialized after
+    the model weights are loaded.
+    Args:
+        quant_config: The quantization config.
+    """
+
+    def __new__(cls, *args, **kwargs):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
+
+    def __init__(self, quant_config):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        tp_size = get_tensor_model_parallel_world_size()
+
+        # WEIGHTS
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(
+            torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        w13_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+            requires_grad=False,
+        )
+        w2_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
+        )
+
+        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+        w13_input_scale = None
+        layer.register_parameter("w13_input_scale", w13_input_scale)
+
+        w2_input_scale = None
+        layer.register_parameter("w2_input_scale", w2_input_scale)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
+        layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
+        layer.w13_weight_scale = Parameter(
+            layer.w13_weight_scale.data, requires_grad=False
+        )
+        layer.w2_weight_scale = Parameter(
+            layer.w2_weight_scale.data, requires_grad=False
+        )
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        inplace: bool = True,
+        no_combine: bool = False,
+    ) -> torch.Tensor:
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.topk import select_experts
+
+        # Expert selection
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
+        )
+
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=inplace,
+            activation=activation,
+            use_int8_w8a8=True,
+            w1_scale=(layer.w13_weight_scale),
+            w2_scale=(layer.w2_weight_scale),
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+            no_combine=no_combine,
+        )
```
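A note on the `__new__` override in `W8A8Int8MoEMethod` above: the class cannot inherit from `FusedMoEMethodBase` at definition time without creating a circular import, so it rebuilds itself at instantiation time with the base class injected. A minimal standalone sketch of the same pattern (class names here are illustrative stand-ins, not sglang APIs):

```python
class Base:
    # Stand-in for FusedMoEMethodBase, which sglang imports lazily inside
    # __new__ to break the circular import.
    def shared_helper(self) -> str:
        return "helper from base"


class Method:
    def __new__(cls, *args, **kwargs):
        # Rebuild the class with Base injected as its parent, then create and
        # initialize an instance of that new class instead of `cls`.
        new_cls = type(
            cls.__name__,
            (Base,),
            # Drop per-class descriptors that must not be copied over.
            {
                k: v
                for k, v in cls.__dict__.items()
                if k not in ("__dict__", "__weakref__")
            },
        )
        obj = super(new_cls, new_cls).__new__(new_cls)
        # Python only auto-runs __init__ when __new__ returns an instance of
        # `cls`; new_cls is not a subclass of Method, so call it explicitly.
        obj.__init__(*args, **kwargs)
        return obj

    def __init__(self, config):
        self.config = config


m = Method("int8-config")
assert isinstance(m, Base)  # the instance really inherits from Base
assert m.shared_helper() == "helper from base"
assert m.config == "int8-config"
```

The net effect is that every `W8A8Int8MoEMethod(...)` call returns an object that passes `isinstance(obj, FusedMoEMethodBase)`, so it can be used anywhere a fused-MoE quantization method is expected.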
sglang/srt/layers/rotary_embedding.py
CHANGED
```diff
@@ -403,12 +403,12 @@ def _yarn_find_correction_range(
 
 
 def _yarn_linear_ramp_mask(
-    low: float, high: float, dim: int, dtype: torch.dtype
+    low: float, high: float, dim: int, dtype: torch.dtype, device: torch.device = None
 ) -> torch.Tensor:
     if low == high:
         high += 0.001  # Prevent singularity
 
-    linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
+    linear_func = (torch.arange(dim, dtype=dtype, device=device) - low) / (high - low)
     ramp_func = torch.clamp(linear_func, 0, 1)
     return ramp_func
 
@@ -688,7 +688,9 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         # Get n-d rotational scaling corrected for extrapolation
         inv_freq_mask = (
             1
-            - _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float)
+            - _yarn_linear_ramp_mask(
+                low, high, self.rotary_dim // 2, dtype=torch.float, device=self.device
+            )
         ) * self.extrapolation_factor
         inv_freq = (
             inv_freq_interpolation * (1 - inv_freq_mask)
```
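Both `rotary_embedding.py` hunks thread a `device` argument through `_yarn_linear_ramp_mask`, so the YaRN ramp mask is created directly on the device that holds `inv_freq` instead of defaulting to CPU. For reference, the ramp is a clamped linear interpolation over the rotary frequency indices, `ramp[i] = clamp((i - low) / (high - low), 0, 1)`; a self-contained version of the patched helper:

```python
import torch


def yarn_linear_ramp_mask(low, high, dim, dtype, device=None):
    # ramp[i] = clamp((i - low) / (high - low), 0, 1) for i in [0, dim)
    if low == high:
        high += 0.001  # prevent division by zero
    linear = (torch.arange(dim, dtype=dtype, device=device) - low) / (high - low)
    return torch.clamp(linear, 0, 1)


# Passing device= avoids materializing the mask on CPU first, which would
# otherwise force an implicit host-to-device copy (or a device-mismatch
# error) inside DeepseekScalingRotaryEmbedding.
print(yarn_linear_ramp_mask(4.0, 12.0, 16, torch.float32, device="cpu"))
```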
sglang/srt/layers/sampler.py
CHANGED
```diff
@@ -1,5 +1,5 @@
 import logging
-from typing import List, Optional
+from typing import List
 
 import torch
 import torch.distributed as dist
@@ -42,7 +42,6 @@ class Sampler(nn.Module):
         return_logprob: bool,
         top_logprobs_nums: List[int],
         token_ids_logprobs: List[List[int]],
-        batch_next_token_ids: Optional[torch.Tensor] = None,
     ):
         """Run a sampler & compute logprobs and update logits_output accordingly.
 
@@ -72,8 +71,7 @@ class Sampler(nn.Module):
 
         if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
-            if batch_next_token_ids is None:
-                batch_next_token_ids = torch.argmax(logits, -1)
+            batch_next_token_ids = torch.argmax(logits, -1)
             if return_logprob:
                 logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
         else:
@@ -94,43 +92,39 @@ class Sampler(nn.Module):
                         top_p_normalize_probs_torch(probs, sampling_info.top_ps)
                     ).clamp(min=torch.finfo(probs.dtype).min)
 
-                if batch_next_token_ids is None:
-                    max_top_k_round, batch_size = 32, probs.shape[0]
-                    uniform_samples = torch.rand(
-                        (max_top_k_round, batch_size), device=probs.device
+                max_top_k_round, batch_size = 32, probs.shape[0]
+                uniform_samples = torch.rand(
+                    (max_top_k_round, batch_size), device=probs.device
+                )
+                if sampling_info.need_min_p_sampling:
+                    probs = top_k_renorm_prob(probs, sampling_info.top_ks)
+                    probs = top_p_renorm_prob(probs, sampling_info.top_ps)
+                    batch_next_token_ids = min_p_sampling_from_probs(
+                        probs, uniform_samples, sampling_info.min_ps
                     )
-                    if sampling_info.need_min_p_sampling:
-                        probs = top_k_renorm_prob(probs, sampling_info.top_ks)
-                        probs = top_p_renorm_prob(probs, sampling_info.top_ps)
-                        batch_next_token_ids = min_p_sampling_from_probs(
-                            probs, uniform_samples, sampling_info.min_ps
-                        )
-                    else:
-                        batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
-                            probs,
-                            uniform_samples,
-                            sampling_info.top_ks,
-                            sampling_info.top_ps,
-                            filter_apply_order="joint",
-                        )
-
-                        if self.use_nan_detection and not torch.all(success):
-                            logger.warning("Detected errors during sampling!")
-                            batch_next_token_ids = torch.zeros_like(
-                                batch_next_token_ids
-                            )
-
-            elif global_server_args_dict["sampling_backend"] == "pytorch":
-                if batch_next_token_ids is None:
-                    # A slower fallback implementation with torch native operations.
-                    batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
+                else:
+                    batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
                         probs,
+                        uniform_samples,
                         sampling_info.top_ks,
                         sampling_info.top_ps,
-                        sampling_info.min_ps,
-                        sampling_info.need_min_p_sampling,
+                        filter_apply_order="joint",
                     )
 
+                    if self.use_nan_detection and not torch.all(success):
+                        logger.warning("Detected errors during sampling!")
+                        batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
+
+            elif global_server_args_dict["sampling_backend"] == "pytorch":
+                # A slower fallback implementation with torch native operations.
+                batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
+                    probs,
+                    sampling_info.top_ks,
+                    sampling_info.top_ps,
+                    sampling_info.min_ps,
+                    sampling_info.need_min_p_sampling,
+                )
+
         if return_logprob:
             # clamp to avoid -inf
             logprobs = torch.log(
```
sglang/srt/layers/vocab_parallel_embedding.py
CHANGED
```diff
@@ -264,7 +264,6 @@ class VocabParallelEmbedding(torch.nn.Module):
         quant_method = None
         if quant_config is not None:
             quant_method = quant_config.get_quant_method(self, prefix=prefix)
-            print("quant_method", quant_method)
         if quant_method is None:
             quant_method = UnquantizedEmbeddingMethod()
 
```
sglang/srt/lora/backend/__init__.py
CHANGED
```diff
@@ -1,23 +1,20 @@
-from .base_backend import BaseLoRABackend
-from .flashinfer_backend import FlashInferLoRABackend
-from .triton_backend import TritonLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 
 
 def get_backend_from_name(name: str) -> BaseLoRABackend:
     """
     Get corresponding backend class from backend's name
     """
-    backend_mapping = {
-        "triton": TritonLoRABackend,
-        "flashinfer": FlashInferLoRABackend,
-    }
+    if name == "triton":
+        from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
 
-    if name in backend_mapping:
-        return backend_mapping[name]
+        return TritonLoRABackend
+    elif name == "flashinfer":
+        from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend
 
-    raise Exception(
-        f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}"
-    )
+        return FlashInferLoRABackend
+    else:
+        raise ValueError(f"Invalid backend: {name}")
 
 
 __all__ = [
```
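The rewrite replaces the eager imports and mapping dict with per-branch lazy imports, so importing `sglang.srt.lora.backend` no longer pulls in the flashinfer backend (and its dependency) unless it is actually requested:

```python
from sglang.srt.lora.backend import get_backend_from_name

# Only sglang.srt.lora.backend.triton_backend is imported here;
# the flashinfer module is never touched.
backend_cls = get_backend_from_name("triton")
```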
sglang/srt/managers/cache_controller.py
CHANGED
```diff
@@ -22,11 +22,34 @@ from typing import List, Optional
 
 import torch
 
-from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MHATokenToKVPoolHost
+from sglang.srt.mem_cache.memory_pool import (
+    MHATokenToKVPoolHost,
+    TokenToKVPoolAllocator,
+)
 
 logger = logging.getLogger(__name__)
 
 
+class LayerDoneCounter:
+    def __init__(self, num_layers):
+        self.counter = num_layers
+        self.condition = threading.Condition()
+
+    def increment(self):
+        with self.condition:
+            self.counter += 1
+            self.condition.notify_all()
+
+    def wait_until(self, threshold):
+        with self.condition:
+            while self.counter <= threshold:
+                self.condition.wait()
+
+    def reset(self):
+        with self.condition:
+            self.counter = 0
+
+
 class CacheOperation:
 
     counter = 0
@@ -127,15 +150,20 @@ class HiCacheController:
 
     def __init__(
         self,
-        mem_pool_device: MHATokenToKVPool,
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         mem_pool_host: MHATokenToKVPoolHost,
+        load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
     ):
-
-        self.mem_pool_device = mem_pool_device
+        self.mem_pool_device_allocator = token_to_kv_pool_allocator
+        self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
         self.mem_pool_host = mem_pool_host
         self.write_policy = write_policy
 
+        self.load_cache_event = load_cache_event
+        self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
+        self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
+
         if write_policy not in [
             "write_through",
             "write_through_selective",
@@ -162,7 +190,7 @@ class HiCacheController:
             target=self.write_thread_func_buffer, daemon=True
         )
         self.load_thread = threading.Thread(
-            target=self.
+            target=self.load_thread_func_layer_by_layer, daemon=True
         )
         self.write_thread.start()
         self.load_thread.start()
@@ -183,7 +211,7 @@ class HiCacheController:
             target=self.write_thread_func_buffer, daemon=True
        )
         self.load_thread = threading.Thread(
-            target=self.
+            target=self.load_thread_func_layer_by_layer, daemon=True
         )
         self.stop_event.clear()
         self.write_thread.start()
@@ -216,10 +244,12 @@ class HiCacheController:
         """
         Load KV caches from host memory to device memory.
         """
-        device_indices = self.mem_pool_device.alloc(len(host_indices))
+        device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
         if device_indices is None:
             return None
         self.mem_pool_host.protect_load(host_indices)
+        # to ensure the device indices are ready before accessed by another CUDA stream
+        torch.cuda.current_stream().synchronize()
         self.load_queue.put(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
@@ -270,6 +300,42 @@ class HiCacheController:
         except Exception as e:
             logger.error(e)
 
+    def load_thread_func_layer_by_layer(self):
+        """
+        Load KV caches from host memory to device memory layer by layer.
+        """
+        with torch.cuda.stream(self.load_stream):
+            while not self.stop_event.is_set():
+                self.load_cache_event.wait(timeout=1)
+                if not self.load_cache_event.is_set():
+                    continue
+                self.load_cache_event.clear()
+
+                batch_operation = None
+                while self.load_queue.qsize() > 0:
+                    op = self.load_queue.get(block=True)
+                    if batch_operation is None:
+                        batch_operation = op
+                    else:
+                        batch_operation.merge(op)
+                if batch_operation is None:
+                    continue
+
+                self.layer_done_counter.reset()
+                for i in range(self.mem_pool_host.layer_num):
+                    flat_data = self.mem_pool_host.get_flat_data_by_layer(
+                        batch_operation.host_indices, i
+                    )
+                    self.mem_pool_device.transfer_per_layer(
+                        batch_operation.device_indices, flat_data, i
+                    )
+                    self.layer_done_counter.increment()
+
+                self.mem_pool_host.complete_io(batch_operation.host_indices)
+                for node_id in batch_operation.node_ids:
+                    if node_id != 0:
+                        self.ack_load_queue.put(node_id)
+
     def write_aux_func(self, no_wait=False):
         """
         Auxiliary function to prepare the buffer for write operations.
@@ -417,7 +483,7 @@ class HiCacheController:
         self, device_indices: torch.Tensor, host_indices: torch.Tensor
     ) -> int:
         if self.mem_pool_host.is_synced(host_indices):
-            self.mem_pool_device.free(device_indices)
+            self.mem_pool_device_allocator.free(device_indices)
             self.mem_pool_host.update_backup(host_indices)
             return len(device_indices)
         else:
```
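The `LayerDoneCounter` introduced above is the handshake between the layer-by-layer loader thread (producer) and the forward pass (consumer): the loader calls `increment()` after each layer's KV data lands on device, and a consumer blocks in `wait_until(i)` until layer `i` is resident. A standalone sketch of that interaction, with illustrative stand-ins for `transfer_per_layer` and the attention loop:

```python
import threading
import time


class LayerDoneCounter:
    # Same primitive as above: a condition-guarded counter of finished layers.
    def __init__(self, num_layers):
        self.counter = num_layers
        self.condition = threading.Condition()

    def increment(self):
        with self.condition:
            self.counter += 1
            self.condition.notify_all()

    def wait_until(self, threshold):
        with self.condition:
            while self.counter <= threshold:
                self.condition.wait()

    def reset(self):
        with self.condition:
            self.counter = 0


NUM_LAYERS = 4
counter = LayerDoneCounter(NUM_LAYERS)
counter.reset()  # a new load batch starts with zero layers resident


def loader():
    # Producer: stand-in for the host->device transfer_per_layer(...) calls.
    for _ in range(NUM_LAYERS):
        time.sleep(0.01)
        counter.increment()


threading.Thread(target=loader, daemon=True).start()

for i in range(NUM_LAYERS):
    counter.wait_until(i)  # attention for layer i needs its KV on device
    print(f"layer {i} KV ready; running attention")
```

Initializing the counter to `num_layers` means every `wait_until` returns immediately outside an active load; `reset()` arms it only for the duration of one batched transfer.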
sglang/srt/managers/data_parallel_controller.py
CHANGED
```diff
@@ -54,7 +54,7 @@ class LoadBalanceMethod(Enum):
 class DataParallelController:
     """A controller that dispatches requests to multiple data parallel workers."""
 
-    def __init__(self, server_args, port_args) -> None:
+    def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None:
         # Parse args
         self.max_total_num_tokens = None
         self.server_args = server_args
```