sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +1 -11
- sglang/bench_serving.py +149 -1
- sglang/lang/chat_template.py +44 -0
- sglang/srt/configs/deepseekvl2.py +3 -0
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +17 -0
- sglang/srt/constrained/xgrammar_backend.py +11 -19
- sglang/srt/conversation.py +30 -3
- sglang/srt/disaggregation/decode.py +4 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +9 -18
- sglang/srt/disaggregation/nixl/conn.py +241 -71
- sglang/srt/disaggregation/utils.py +44 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +14 -2
- sglang/srt/entrypoints/http_server.py +28 -1
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/hf_transformers_utils.py +20 -1
- sglang/srt/layers/attention/flashattention_backend.py +146 -50
- sglang/srt/layers/attention/flashinfer_backend.py +23 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
- sglang/srt/layers/moe/ep_moe/layer.py +120 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +4 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +5 -0
- sglang/srt/layers/quantization/fp8.py +108 -95
- sglang/srt/layers/quantization/fp8_kernel.py +79 -60
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/lora/lora_manager.py +10 -13
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/io_struct.py +10 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/schedule_batch.py +19 -1
- sglang/srt/managers/schedule_policy.py +11 -5
- sglang/srt/managers/scheduler.py +28 -13
- sglang/srt/managers/tokenizer_manager.py +24 -13
- sglang/srt/managers/tp_worker.py +9 -12
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +2 -2
- sglang/srt/model_executor/model_runner.py +44 -33
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_nextn.py +1 -20
- sglang/srt/models/deepseek_v2.py +55 -20
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/llama.py +1 -1
- sglang/srt/models/llama4.py +53 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +24 -40
- sglang/srt/openai_api/protocol.py +28 -16
- sglang/srt/reasoning_parser.py +2 -2
- sglang/srt/sampling/sampling_batch_info.py +54 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +30 -6
- sglang/srt/utils.py +35 -1
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +14 -6
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +90 -80
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
@@ -14,6 +14,9 @@ except ImportError:
 
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
+    fp8_dtype,
+    fp8_max,
+    is_fp8_fnuz,
     per_token_group_quant_fp8,
     scaled_fp8_quant,
     sglang_per_token_quant_fp8,
@@ -30,8 +33,11 @@ from sglang.srt.utils import (
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_fp8_fnuz = is_fp8_fnuz()
 
- [1 removed line not captured in this extract]
+use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
+
+if _is_hip and use_aiter_moe:
     from aiter import gemm_a8w8_blockscale
 
 if _is_cuda:
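The new `use_aiter_moe` gate above is driven by an environment flag. As a rough standalone sketch of that pattern (the real `get_bool_env_var` in `sglang.srt.utils` may differ in detail):

```python
import os

def get_bool_env_var_sketch(name: str, default: str = "false") -> bool:
    # Hypothetical stand-in for sglang's get_bool_env_var helper.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

use_aiter_moe = get_bool_env_var_sketch("SGLANG_AITER_MOE")
print(use_aiter_moe)  # False unless SGLANG_AITER_MOE is set to a truthy value
```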
@@ -43,19 +49,23 @@ use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_K
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
 TORCH_DEVICE_IDENTITY = None
 
- [8 removed lines not captured in this extract]
-    # The condition is
-    #
- [2 removed lines not captured in this extract]
-)
+
+def use_rowwise_torch_scaled_mm():
+    _TORCH_VERSION = torch.__version__.split("+")[0]
+    try:
+        _TORCH_VERSION_TUPLE = tuple(map(int, _TORCH_VERSION.split(".")[:3]))
+    except ValueError:
+        _TORCH_VERSION_TUPLE = (0, 0, 0)
+    if _is_hip:
+        # The condition to determine if it is on a platform that supports
+        # torch._scaled_mm rowwise feature.
+        # The condition is determined once as the operations
+        # are time consuming.
+        return get_device_capability() >= (9, 4) and _TORCH_VERSION_TUPLE >= (2, 7, 0)
+    return False
+
+
+USE_ROWWISE_TORCH_SCALED_MM = use_rowwise_torch_scaled_mm()
 
 
 def cutlass_fp8_supported():
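The `use_rowwise_torch_scaled_mm()` gate above compares a parsed `torch.__version__` against `(2, 7, 0)`. A standalone sketch (not sglang code) of why the split on `+` and the `ValueError` fallback matter for versions like `2.7.0+rocm6.3`:

```python
def parse_torch_version(version: str) -> tuple:
    public = version.split("+")[0]          # drop a local build suffix such as "+rocm6.3"
    try:
        return tuple(map(int, public.split(".")[:3]))
    except ValueError:                      # pre-release strings like "2.8.0a0" fall back
        return (0, 0, 0)

print(parse_torch_version("2.7.0+rocm6.3") >= (2, 7, 0))  # True
print(parse_torch_version("2.6.1") >= (2, 7, 0))          # False
print(parse_torch_version("2.8.0a0"))                     # (0, 0, 0)
```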
@@ -132,7 +142,7 @@ def apply_w8a8_block_fp8_linear(
         output = fp8_blockwise_scaled_mm(
             q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
         )
-    elif _is_hip and
+    elif _is_hip and use_aiter_moe:
         q_input, x_scale = per_token_group_quant_fp8(
             input_2d, block_size[1], column_major_scales=False
         )
@@ -164,18 +174,21 @@ def apply_w8a8_block_fp8_linear(
 
 
 def input_to_float8(
-    x: torch.Tensor, dtype: torch.dtype =
+    x: torch.Tensor, dtype: torch.dtype = fp8_dtype
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """This function quantizes input values to float8 values with tensor-wise quantization."""
-    finfo = torch.finfo(dtype)
     min_val, max_val = x.aminmax()
     amax = torch.maximum(min_val.abs(), max_val.abs()).float().clamp(min=1e-12)
- [1 removed line not captured in this extract]
-    if
-    dtype =
- [3 removed lines not captured in this extract]
+
+    if _is_fp8_fnuz:
+        dtype = fp8_dtype
+        fp_max = fp8_max
+    else:
+        finfo = torch.finfo(dtype)
+        fp_max = finfo.max
+
+    scale = fp_max / amax
+    x_scl_sat = (x.float() * scale).clamp(min=-fp_max, max=fp_max)
     return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
 
 
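To make the scale convention in `input_to_float8` concrete, here is a standalone round-trip sketch (plain PyTorch, not sglang code): the tensor is scaled up to fill the FP8 range, cast, and the reciprocal scale is kept for dequantization.

```python
import torch

def quantize_fp8_tensorwise(x: torch.Tensor, dtype=torch.float8_e4m3fn):
    # Tensor-wise quantization: one scale for the whole tensor.
    finfo = torch.finfo(dtype)
    amax = x.abs().amax().float().clamp(min=1e-12)
    scale = finfo.max / amax
    x_q = (x.float() * scale).clamp(min=-finfo.max, max=finfo.max).to(dtype)
    return x_q, scale.reciprocal()      # reciprocal is the dequantization scale

x = torch.randn(4, 8)
x_q, deq_scale = quantize_fp8_tensorwise(x)
x_hat = x_q.to(torch.float32) * deq_scale
print((x - x_hat).abs().max())          # small round-trip error
```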
@@ -222,6 +235,41 @@ def block_quant_to_tensor_quant(
     return x_q_tensor, scale
 
 
+def block_quant_dequant(
+    x_q_block: torch.Tensor,
+    x_s: torch.Tensor,
+    block_size: List[int],
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    """This function converts block-wise quantization to unquantized.
+    The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale
+    and the block size.
+    The output is an unquantized tensor with dtype.
+    """
+    block_n, block_k = block_size[0], block_size[1]
+    n, k = x_q_block.shape
+    n_tiles = (n + block_n - 1) // block_n
+    k_tiles = (k + block_k - 1) // block_k
+    assert n_tiles == x_s.shape[0]
+    assert k_tiles == x_s.shape[1]
+
+    x_dq_block = torch.empty_like(x_q_block, dtype=dtype)
+
+    for j in range(n_tiles):
+        for i in range(k_tiles):
+            x_q_block_tile = x_q_block[
+                j * block_n : min((j + 1) * block_n, n),
+                i * block_k : min((i + 1) * block_k, k),
+            ]
+            x_dq_block_tile = x_dq_block[
+                j * block_n : min((j + 1) * block_n, n),
+                i * block_k : min((i + 1) * block_k, k),
+            ]
+            x_dq_block_tile[:, :] = x_q_block_tile.to(torch.float32) * x_s[j][i]
+
+    return x_dq_block
+
+
 def channel_quant_to_tensor_quant(
     x_q_channel: torch.Tensor,
     x_s: torch.Tensor,
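As a small, self-contained illustration of the block layout the new `block_quant_dequant` assumes (not sglang code): an (n, k) quantized tensor pairs with an (n_tiles, k_tiles) scale grid, and each (block_n, block_k) tile is rescaled by its own factor.

```python
import torch

n, k, block_n, block_k = 8, 8, 4, 4
x_q = torch.randint(-10, 10, (n, k)).to(torch.float16)   # stand-in for FP8 data
x_s = torch.tensor([[0.5, 2.0], [1.0, 4.0]])              # one scale per 4x4 tile

x_dq = torch.empty(n, k)
for j in range(n // block_n):
    for i in range(k // block_k):
        tile = x_q[j * block_n:(j + 1) * block_n, i * block_k:(i + 1) * block_k]
        x_dq[j * block_n:(j + 1) * block_n, i * block_k:(i + 1) * block_k] = (
            tile.to(torch.float32) * x_s[j, i]
        )

print(x_dq[0, 0] == x_q[0, 0].float() * 0.5)   # top-left tile uses scale 0.5
```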
sglang/srt/layers/quantization/kv_cache.py
CHANGED
@@ -8,10 +8,8 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import is_hip
-
-_is_hip = is_hip()
 
 logger = logging.getLogger(__name__)
 
@@ -44,11 +42,6 @@ class BaseKVCacheMethod(QuantizeMethodBase):
             torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
         )
 
-    @classmethod
-    def is_fp8_fnuz(cls) -> bool:
-        # only device 0 is checked, this assumes MI300 platforms are homogeneous
-        return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
-
     def apply(self, layer: torch.nn.Module) -> torch.Tensor:
         raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
 
@@ -57,7 +50,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
             # We prefer to use separate k_scale and v_scale if present
             k_scale = layer.k_scale.to("cpu").tolist()
             v_scale = layer.v_scale.to("cpu").tolist()
-            if
+            if is_fp8_fnuz():
                 k_scale *= 2
                 v_scale *= 2
         elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
@@ -73,7 +66,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
             scale_to_duplicate = max(layer.k_scale, layer.v_scale)
             k_scale = scale_to_duplicate.to("cpu").tolist()
             v_scale = scale_to_duplicate.to("cpu").tolist()
-            if
+            if is_fp8_fnuz():
                 k_scale *= 2
                 v_scale *= 2
 
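The `* 2` on fnuz platforms follows from the FP8 formats themselves: `float8_e4m3fnuz` uses exponent bias 8 instead of 7, so the same bit pattern decodes to half the value it has as `float8_e4m3fn`, and scales serialized for e4m3fn checkpoints must be doubled. A quick standalone check (not sglang code):

```python
import torch

print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0, roughly half the range
```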
sglang/srt/layers/quantization/utils.py
CHANGED
@@ -14,11 +14,6 @@ if not _is_cuda:
     from vllm._custom_ops import scaled_fp8_quant
 
 
-def is_fp8_fnuz() -> bool:
-    # only device 0 is checked, this assumes MI300 platforms are homogeneous
-    return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
-
-
 def is_layer_skipped(
     prefix: str,
     ignored_layers: List[str],
sglang/srt/layers/quantization/w8a8_fp8.py
CHANGED
@@ -9,16 +9,20 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8_kernel import
+from sglang.srt.layers.quantization.fp8_kernel import (
+    fp8_dtype,
+    is_fp8_fnuz,
+    per_token_group_quant_fp8,
+)
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     cutlass_fp8_supported,
     input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
-from sglang.srt.utils import
+from sglang.srt.utils import set_weight_attrs
 
- [1 removed line not captured in this extract]
+_is_fp8_fnuz = is_fp8_fnuz()
 
 
 class W8A8Fp8Config(QuantizationConfig):
@@ -97,7 +101,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
         if self.quantization_config.is_checkpoint_fp8_serialized:
             weight_scale = layer.weight_scale.detach()
             # If checkpoint offline quantized with w8a8_fp8, load the weight and weight_scale directly.
-            if
+            if _is_fp8_fnuz:
                 weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight, weight_scale=weight_scale
                 )
@@ -113,14 +117,9 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
                     layer.weight, layer.weight.shape[-1]
                 )
                 weight_scale = weight_scale.t().contiguous()
-                if _is_hip:
-                    weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                        weight=weight, weight_scale=weight_scale
-                    )
             else:
                 # if cutlass not supported, we fall back to use torch._scaled_mm
                 # which requires per tensor quantization on weight
-                fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
                 qweight, weight_scale = input_to_float8(layer.weight, dtype=fp8_dtype)
 
             # Update the layer with the new values.
@@ -227,7 +226,6 @@ class W8A8FP8MoEMethod:
     ):
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 
-        fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
sglang/srt/lora/lora_manager.py
CHANGED
@@ -156,18 +156,15 @@ class LoRAManager:
         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
 
-        if
- [4 removed lines not captured in this extract]
+        if (
+            hasattr(self, "max_bs_in_cuda_graph")
+            and bs <= self.max_bs_in_cuda_graph
+            and forward_batch.forward_mode.is_cuda_graph()
+        ):
+            # Do in-place updates when CUDA graph is enabled and the batch forward mode
+            # could use CUDA graph.
             self.cuda_graph_batch_info.bs = bs
- [1 removed line not captured in this extract]
-            self.cuda_graph_batch_info.seg_lens[:bs].copy_(
-                forward_batch.extend_seq_lens
-            )
-        else:
-            self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
+            self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
             torch.cumsum(
                 self.cuda_graph_batch_info.seg_lens[:bs],
                 dim=0,
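For decode batches each sequence contributes one token, so the in-place update above fills `seg_lens` with ones and takes a cumulative sum to obtain segment boundaries. A standalone illustration of that bookkeeping (hypothetical sizes and dtypes, not sglang code):

```python
import torch

bs = 4
max_bs = 8                                      # hypothetical CUDA-graph capture size
seg_lens = torch.zeros(max_bs, dtype=torch.int64)
seg_indptr = torch.zeros(max_bs + 1, dtype=torch.int64)

seg_lens[:bs].fill_(1)                          # decode: one token per running request
torch.cumsum(seg_lens[:bs], dim=0, out=seg_indptr[1 : bs + 1])
print(seg_indptr[: bs + 1].tolist())            # [0, 1, 2, 3, 4]
```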
@@ -201,10 +198,10 @@ class LoRAManager:
         max_len = int(torch.max(seg_lens))
         weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
 
-        lora_ranks = torch.
+        lora_ranks = torch.zeros(
             (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
         )
-        scalings = torch.
+        scalings = torch.zeros(
             (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
         )
         for i, lora_path in enumerate(forward_batch.lora_paths):
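The zero-initialized per-slot tables above act as lookups indexed by each request's adapter slot; slots with no adapter naturally keep rank 0 and scaling 0. A toy standalone sketch of that data structure (hypothetical slot assignments, not sglang code):

```python
import torch

max_loras_per_batch = 4
lora_ranks = torch.zeros((max_loras_per_batch,), dtype=torch.int64)
scalings = torch.zeros((max_loras_per_batch,), dtype=torch.float)

# Hypothetical adapters occupying buffer slots 0 and 2.
lora_ranks[0], scalings[0] = 16, 2.0
lora_ranks[2], scalings[2] = 8, 1.0

weight_indices = torch.tensor([0, 2, 0])    # which slot each request uses
print(lora_ranks[weight_indices].tolist())  # per-request ranks: [16, 8, 16]
print(scalings[weight_indices].tolist())    # per-request scalings: [2.0, 1.0, 2.0]
```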
sglang/srt/managers/cache_controller.py
CHANGED
@@ -268,98 +268,97 @@ class HiCacheController:
         """
         Directly write through KV caches to host memory without buffering.
         """
- [18 removed lines not captured in this extract]
+        torch.cuda.set_stream(self.write_stream)
+        while not self.stop_event.is_set():
+            try:
+                operation = self.write_queue.get(block=True, timeout=1)
+                self.mem_pool_host.write_page_all_layers(
+                    operation.host_indices,
+                    operation.device_indices,
+                    self.mem_pool_device,
+                )
+                self.write_stream.synchronize()
+                self.mem_pool_host.complete_io(operation.host_indices)
+                for node_id in operation.node_ids:
+                    if node_id != 0:
+                        self.ack_write_queue.put(node_id)
+            except Empty:
+                continue
+            except Exception as e:
+                logger.error(e)
 
     def load_thread_func_direct(self):
         """
         Directly load KV caches from host memory to device memory without buffering.
         """
- [17 removed lines not captured in this extract]
-            except Exception as e:
-                logger.error(e)
+        torch.cuda.set_stream(self.load_stream)
+        while not self.stop_event.is_set():
+            try:
+                operation = self.load_queue.get(block=True, timeout=1)
+                # time.sleep(18e-6 * len(operation.host_indices))
+                operation.data = self.mem_pool_host.get_flat_data(
+                    operation.host_indices
+                )
+                self.mem_pool_device.transfer(operation.device_indices, operation.data)
+                self.mem_pool_host.complete_io(operation.host_indices)
+                for node_id in operation.node_ids:
+                    if node_id != 0:
+                        self.ack_load_queue.put(node_id)
+            except Empty:
+                continue
+            except Exception as e:
+                logger.error(e)
 
     def load_thread_func_layer_by_layer(self):
         """
         Load KV caches from host memory to device memory layer by layer.
         """
- [6 removed lines not captured in this extract]
+        torch.cuda.set_stream(self.load_stream)
+        while not self.stop_event.is_set():
+            self.load_cache_event.wait(timeout=1)
+            if not self.load_cache_event.is_set():
+                continue
+            self.load_cache_event.clear()
 
- [3 removed lines not captured in this extract]
-                if batch_operation is None:
-                    batch_operation = op
-                else:
-                    batch_operation.merge(op)
+            batch_operation = None
+            while self.load_queue.qsize() > 0:
+                op = self.load_queue.get(block=True)
                 if batch_operation is None:
- [1 removed line not captured in this extract]
+                    batch_operation = op
+                else:
+                    batch_operation.merge(op)
+            if batch_operation is None:
+                continue
 
- [23 removed lines not captured in this extract]
+            self.layer_done_counter.reset()
+            for i in range(self.mem_pool_host.layer_num):
+                if self.page_size == 1:
+                    flat_data = self.mem_pool_host.get_flat_data_by_layer(
+                        batch_operation.host_indices, i
+                    )
+                    self.mem_pool_device.transfer_per_layer(
+                        batch_operation.device_indices, flat_data, i
+                    )
+                else:
+                    self.mem_pool_host.load_page_per_layer(
+                        batch_operation.host_indices,
+                        batch_operation.device_indices,
+                        self.mem_pool_device,
+                        i,
+                    )
+                self.load_stream.synchronize()
+                self.layer_done_counter.increment()
+
+            self.mem_pool_host.complete_io(batch_operation.host_indices)
+            for node_id in batch_operation.node_ids:
+                if node_id != 0:
+                    self.ack_load_queue.put(node_id)
 
     def write_aux_func(self, no_wait=False):
         """
         Auxiliary function to prepare the buffer for write operations.
         """
+        torch.cuda.set_stream(self.write_stream)
 
         def _to_op(op_):
             assert op_.device_indices.is_cuda, "Device indices should be on GPU"
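The rewritten thread bodies above all follow the same consumer pattern: block on a queue with a timeout, treat `Empty` as a signal to re-check the stop event, and acknowledge finished work on a separate queue. A stripped-down standalone sketch of that pattern (hypothetical names, not sglang code):

```python
import queue
import threading

stop_event = threading.Event()
work_queue: "queue.Queue[int]" = queue.Queue()
ack_queue: "queue.Queue[int]" = queue.Queue()

def worker():
    while not stop_event.is_set():
        try:
            item = work_queue.get(block=True, timeout=1)  # wake up periodically
            ack_queue.put(item)                           # acknowledge completion
        except queue.Empty:
            continue                                      # re-check stop_event

t = threading.Thread(target=worker, daemon=True)
t.start()
work_queue.put(42)
print(ack_queue.get())  # 42
stop_event.set()
t.join()
```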
@@ -370,44 +369,42 @@ class HiCacheController:
             return op_
 
         buffer = None
- [6 removed lines not captured in this extract]
-                    // self.write_buffer.max_buffer_size
-                )
+        while not self.stop_event.is_set():
+            try:
+                operation = self.write_queue.get(block=True, timeout=1)
+                factor = (
+                    len(operation.device_indices) // self.write_buffer.max_buffer_size
+                )
 
- [2 removed lines not captured in this extract]
-                        _to_op(buffer)
-                        buffer = None
- [1 removed line not captured in this extract]
-                    if factor < 2:
-                        _to_op(operation)
-                    else:
-                        split_ops = operation.split(factor)
-                        for op_ in split_ops:
-                            _to_op(op_)
-                    continue
- [1 removed line not captured in this extract]
-                if buffer is None:
-                    buffer = operation
-                else:
-                    buffer.merge(operation)
-                if (
-                    no_wait
-                    or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
-                    or self.write_queue.empty()
-                    or self.write_buffer.empty()
-                ):
+                if factor >= 1:
+                    if buffer is not None:
                         _to_op(buffer)
                         buffer = None
- [1 removed line not captured in this extract]
+
+                    if factor < 2:
+                        _to_op(operation)
+                    else:
+                        split_ops = operation.split(factor)
+                        for op_ in split_ops:
+                            _to_op(op_)
                     continue
- [2 removed lines not captured in this extract]
+
+                if buffer is None:
+                    buffer = operation
+                else:
+                    buffer.merge(operation)
+                if (
+                    no_wait
+                    or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
+                    or self.write_queue.empty()
+                    or self.write_buffer.empty()
+                ):
+                    _to_op(buffer)
+                    buffer = None
+            except Empty:
+                continue
+            except Exception as e:
+                logger.error(e)
 
     def load_aux_func(self):
         """
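The new `write_aux_func` body keeps the same batching idea in a different shape: oversized operations are split into roughly `factor` pieces, everything else accumulates into a buffer that is flushed once it is large enough or the queues drain. A toy standalone sketch of that decision logic (hypothetical sizes, not sglang code):

```python
def plan_write_ops(sizes, max_buffer_size=8):
    plan, buffered = [], 0
    for size in sizes:
        factor = size // max_buffer_size
        if factor >= 1:
            if buffered:                     # flush whatever was accumulated first
                plan.append(("flush", buffered))
                buffered = 0
            if factor < 2:
                plan.append(("direct", size))
            else:                            # split a very large op into ~factor pieces
                plan.extend(("piece", size // factor) for _ in range(factor))
            continue
        buffered += size
        if buffered >= max_buffer_size:      # buffer is full enough to flush
            plan.append(("flush", buffered))
            buffered = 0
    if buffered:
        plan.append(("flush", buffered))
    return plan

print(plan_write_ops([3, 2, 20, 5, 4]))
# [('flush', 5), ('piece', 10), ('piece', 10), ('flush', 9)]
```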
@@ -484,19 +481,18 @@ class HiCacheController:
         aux_thread.join()
 
     def load_thread_func_buffer(self):
+        torch.cuda.set_stream(self.load_stream)
         aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
         aux_thread.start()
- [9 removed lines not captured in this extract]
-                    if node_id != 0:
-                        self.ack_load_queue.put(node_id)
+        while not self.stop_event.is_set():
+            operation = self.load_buffer.get()
+            if operation is None:
+                continue
+            self.mem_pool_device.transfer(operation.device_indices, operation.data)
+            self.mem_pool_host.complete_io(operation.host_indices)
+            for node_id in operation.node_ids:
+                if node_id != 0:
+                    self.ack_load_queue.put(node_id)
         aux_thread.join()
 
     def evict_device(
sglang/srt/managers/io_struct.py
CHANGED
@@ -790,6 +790,16 @@ class ResumeMemoryOccupationReqOutput:
     pass
 
 
+@dataclass
+class SlowDownReqInput:
+    forward_sleep_time: Optional[float]
+
+
+@dataclass
+class SlowDownReqOutput:
+    pass
+
+
 @dataclass
 class AbortReq:
     # The request id
sglang/srt/managers/multimodal_processors/base_processor.py
CHANGED
@@ -8,6 +8,7 @@ from typing import List, Optional
 
 import numpy as np
 import PIL
+import torch
 from PIL import Image
 from transformers import BaseImageProcessorFast
 
@@ -89,6 +90,10 @@ class BaseMultimodalProcessor(ABC):
             return_tensors="pt",
             **kwargs,
         )
+        if "pixel_values" in result and isinstance(
+            result["pixel_values"], torch.Tensor
+        ):
+            result["pixel_values"] = result["pixel_values"].to("cpu")
         return result
 
     @abstractmethod