sglang 0.5.2rc0__py3-none-any.whl → 0.5.2rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/configs/model_config.py +2 -1
- sglang/srt/distributed/parallel_state.py +3 -1
- sglang/srt/entrypoints/engine.py +1 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
- sglang/srt/layers/moe/ep_moe/layer.py +2 -7
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +8 -0
- sglang/srt/layers/quantization/w4afp8.py +30 -25
- sglang/srt/managers/detokenizer_manager.py +0 -34
- sglang/srt/managers/multi_tokenizer_mixin.py +44 -6
- sglang/srt/managers/scheduler.py +3 -0
- sglang/srt/mem_cache/hiradix_cache.py +19 -3
- sglang/srt/mem_cache/memory_pool_host.py +2 -0
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +27 -6
- sglang/srt/models/deepseek_v2.py +5 -0
- sglang/srt/models/gpt_oss.py +5 -4
- sglang/srt/models/longcat_flash.py +26 -15
- sglang/srt/models/longcat_flash_nextn.py +23 -15
- sglang/srt/utils.py +0 -10
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/version.py +1 -1
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc1.dist-info}/METADATA +2 -2
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc1.dist-info}/RECORD +32 -29
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc1.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/layer.py
CHANGED
@@ -175,6 +175,8 @@ class FusedMoE(torch.nn.Module):
         self.moe_tp_rank = get_moe_tensor_parallel_rank()
         assert num_experts % self.moe_ep_size == 0
         self.num_local_experts = num_experts // self.moe_ep_size
+        self.start_expert_id = self.moe_ep_rank * self.num_local_experts
+        self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
         if self.moe_ep_size > 1:
             # TODO(ch-wan): support shared experts fusion
             # Create a tensor of size num_experts filled with -1
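The two added lines give each expert-parallel (EP) rank a contiguous range of expert ids. A quick illustrative calculation (example values only, not taken from the diff):

num_experts, moe_ep_size = 8, 4
num_local_experts = num_experts // moe_ep_size  # 2 experts per EP rank
for moe_ep_rank in range(moe_ep_size):
    start_expert_id = moe_ep_rank * num_local_experts
    end_expert_id = start_expert_id + num_local_experts - 1
    print(moe_ep_rank, (start_expert_id, end_expert_id))
# rank 0 -> (0, 1), rank 1 -> (2, 3), rank 2 -> (4, 5), rank 3 -> (6, 7)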
@@ -593,8 +595,9 @@ class FusedMoE(torch.nn.Module):

         if (
             "compressed" in self.quant_method.__class__.__name__.lower()
-            and param.data[expert_id] != 1
-            and (param.data[expert_id] - loaded_weight).abs() > 1e-5
+            or "w4afp8" in self.quant_config.get_name()
+            and (param.data[expert_id] != 1).any()
+            and ((param.data[expert_id] - loaded_weight).abs() > 1e-5).any()
         ):
             raise ValueError(
                 "input_scales of w1 and w3 of a layer "
sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py
ADDED
@@ -0,0 +1,87 @@
+from __future__ import annotations
+
+from typing import Tuple
+
+import torch
+import triton
+
+from sglang.srt.utils import is_cuda, is_hip
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+if _is_cuda or _is_hip:
+    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+
+
+def moe_align_block_size(
+    topk_ids: torch.Tensor, block_size: int, num_experts: int
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Aligns the token distribution across experts to be compatible with block
+    size for matrix multiplication.
+
+    Parameters:
+    - topk_ids: A tensor of shape [total_tokens, top_k] representing the
+      top-k expert indices for each token.
+    - block_size: The block size used in block matrix multiplication.
+    - num_experts: The total number of experts.
+
+    Returns:
+    - sorted_token_ids: A tensor containing the sorted token indices according
+      to their allocated expert.
+    - expert_ids: A tensor indicating the assigned expert index for each block.
+    - num_tokens_post_padded: The total number of tokens after padding,
+      ensuring divisibility by block_size.
+
+    This function pads the number of tokens that each expert needs to process
+    so that it is divisible by block_size.
+    Padding ensures that during block matrix multiplication, the dimensions
+    align correctly.
+
+    Example:
+    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
+    block_size = 4, and num_experts = 4:
+    - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
+      with each expert needing to process 3 tokens.
+    - As block_size is 4, we pad 1 token for each expert.
+    - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
+    - Then append padding tokens [12, 12, 12, 12] for each block.
+    - After sorting by expert index, we obtain token_ids
+      [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
+      Tokens 12 are non-existent (padding) and are ignored in
+      the subsequent matrix multiplication.
+    - The padding ensures that the total number of tokens is now divisible
+      by block_size for proper block matrix operations.
+    """
+    max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1)
+    sorted_ids = torch.empty(
+        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
+    )
+    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
+    expert_ids = torch.empty(
+        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
+    )
+    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
+
+    # In EP, expert_ids for filtered experts are -1. We have num_experts + 1 ids in total.
+    cumsum_buffer = torch.empty(
+        (num_experts + 2,), dtype=torch.int32, device=topk_ids.device
+    )
+
+    # Threshold based on benchmark results
+    fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
+    if not fuse_sorted_ids_padding:
+        sorted_ids.fill_(topk_ids.numel())
+
+    sgl_moe_align_block_size(
+        topk_ids,
+        num_experts + 1,
+        block_size,
+        sorted_ids,
+        expert_ids,
+        num_tokens_post_pad,
+        cumsum_buffer,
+        fuse_sorted_ids_padding,
+    )
+    return sorted_ids, expert_ids, num_tokens_post_pad
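The docstring above describes the padding and sorting contract of the sgl_kernel implementation. Below is a minimal pure-PyTorch sketch of the same idea, for illustration only; the function name and the small example are assumptions, and experts with zero tokens are simply skipped here:

import torch

def moe_align_block_size_ref(topk_ids: torch.Tensor, block_size: int, num_experts: int):
    flat = topk_ids.flatten()
    pad_id = flat.numel()  # sentinel index appended as padding
    counts = torch.bincount(flat, minlength=num_experts)
    padded = ((counts + block_size - 1) // block_size) * block_size
    sorted_ids, expert_ids = [], []
    for e in range(num_experts):
        tok = torch.nonzero(flat == e, as_tuple=False).flatten().tolist()
        tok += [pad_id] * (int(padded[e]) - len(tok))  # pad each expert to a block multiple
        sorted_ids += tok
        expert_ids += [e] * (int(padded[e]) // block_size)  # one expert id per block
    return sorted_ids, expert_ids, len(sorted_ids)

topk_ids = torch.tensor([[0, 1, 2], [1, 2, 3], [0, 2, 3], [0, 1, 3]])
print(moe_align_block_size_ref(topk_ids, block_size=4, num_experts=4))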
sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py
CHANGED
@@ -132,9 +132,17 @@ def _compile_deep_gemm_one_type_all(
         kernel_type, max_m=max(m_list), n=n, k=k, num_groups=num_groups
     )

+    old_compile_mode = deep_gemm.get_compile_mode()
+    deep_gemm.set_compile_mode(1)
     # TODO can use multi thread
     for m in tqdm(m_list, desc=f"DeepGEMM warmup"):
         executor.execute(m=m)
+    deep_gemm.set_compile_mode(old_compile_mode)
+
+    # clean up input buffers
+    torch.cuda.current_stream().synchronize()
+    del executor
+    torch.cuda.empty_cache()


 class _BaseWarmupExecutor:
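The warmup change follows a save/set/restore pattern around deep_gemm's compile mode and then frees the temporary warmup buffers. A generic sketch of that pattern, where get_mode, set_mode, and run_warmup are stand-ins and not the deep_gemm API:

import torch

def run_with_mode(run_warmup, get_mode, set_mode, warmup_mode=1):
    old_mode = get_mode()
    set_mode(warmup_mode)
    try:
        run_warmup()
    finally:
        set_mode(old_mode)  # restore the caller's mode even if warmup fails
        torch.cuda.current_stream().synchronize()  # wait for warmup kernels to finish
        torch.cuda.empty_cache()  # release the temporary input buffers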
sglang/srt/layers/quantization/w4afp8.py
CHANGED
@@ -1,12 +1,14 @@
 from __future__ import annotations

 import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

 import torch
 from torch.nn import Module
 from torch.nn.parameter import Parameter

+from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
+from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
     QuantizationConfig,
@@ -91,12 +93,13 @@ class W4AFp8Config(QuantizationConfig):
         from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.ep_moe.layer import EPMoE
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+        from sglang.srt.managers.schedule_batch import global_server_args_dict

         if isinstance(layer, LinearBase):
             if is_layer_skipped(prefix, self.ignored_layers):
                 return UnquantizedLinearMethod()
             return Fp8LinearMethod(self)
-        elif isinstance(layer, EPMoE):
+        elif isinstance(layer, FusedMoE):
             return W4AFp8MoEMethod(self)
         return None

@@ -104,8 +107,24 @@ class W4AFp8Config(QuantizationConfig):
         return []


-class W4AFp8MoEMethod(FusedMoEMethodBase):
+def interleave_scales(scales: torch.Tensor) -> torch.Tensor:
+    """Interleave scales in groups of 4 similar to TRT-LLM implementation."""
+    s_shape = scales.shape
+    # Reshape to separate groups of 4
+    alignment = 4 if s_shape[2] % 4 == 0 else 1
+    scales_interleaved = scales.reshape(
+        s_shape[0], s_shape[1], (s_shape[2] // alignment), alignment
+    )
+    # Permute dimensions to interleave
+    scales_interleaved = scales_interleaved.permute(0, 2, 1, 3)
+    # Reshape back to original dimensions but with interleaved values
+    scales_interleaved = scales_interleaved.reshape(
+        s_shape[0], s_shape[2] // alignment, s_shape[1] * alignment
+    )
+    return scales_interleaved.contiguous()
+

+class W4AFp8MoEMethod(FusedMoEMethodBase):
     def __init__(self, quant_config: W4AFp8Config):
         self.quant_config = quant_config

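A quick shape check of the interleave_scales helper above, on a made-up scales tensor (the shape and its interpretation are assumptions for illustration):

import torch

scales = torch.arange(48.0).reshape(2, 3, 8)  # (num_experts, N, K_groups)
out = interleave_scales(scales)
print(out.shape)  # torch.Size([2, 2, 12]) -> (num_experts, K_groups // 4, N * 4)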
@@ -234,33 +253,18 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):

         return

-    def _interleave_scales(self, scales: torch.Tensor) -> torch.Tensor:
-        """Interleave scales in groups of 4 similar to TRT-LLM implementation."""
-        s_shape = scales.shape
-        # Reshape to separate groups of 4
-        scales_interleaved = scales.reshape(
-            s_shape[0], s_shape[1], (s_shape[2] // 4), 4
-        )
-        # Permute dimensions to interleave
-        scales_interleaved = scales_interleaved.permute(0, 2, 1, 3)
-        # Reshape back to original dimensions but with interleaved values
-        scales_interleaved = scales_interleaved.reshape(
-            s_shape[0], s_shape[2] // 4, s_shape[1] * 4
-        )
-        return scales_interleaved.contiguous()
-
     def process_weights_after_loading(self, layer: Module) -> None:
         dtype = torch.bfloat16
         device = layer.w2_weight.device

         # Interleave w13_weight_scale (gate_up_proj)
         w13_weight_scale = layer.w13_weight_scale_inv.to(dtype)
-        w13_weight_scale = self._interleave_scales(w13_weight_scale)
+        w13_weight_scale = interleave_scales(w13_weight_scale)
         layer.w13_weight_scale_inv = Parameter(w13_weight_scale, requires_grad=False)

         # Interleave w2_weight_scale (down_proj)
         w2_weight_scale = layer.w2_weight_scale_inv.to(dtype)
-        w2_weight_scale = self._interleave_scales(w2_weight_scale)
+        w2_weight_scale = interleave_scales(w2_weight_scale)
         layer.w2_weight_scale_inv = Parameter(w2_weight_scale, requires_grad=False)

         # Process input scales
@@ -291,11 +295,12 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):

         topk_weights, topk_ids, _ = topk_output
         local_topk_ids = topk_ids
-
-
-
-
-
+        if get_moe_expert_parallel_world_size() > 1:
+            local_topk_ids = torch.where(
+                topk_ids == -1,
+                layer.num_experts,
+                topk_ids,
+            )

         output = cutlass_w4a8_moe(
             layer.start_expert_id,
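Under expert parallelism, entries routed to experts this rank does not own are marked -1, and the added torch.where call remaps them to the out-of-range id num_experts before calling cutlass_w4a8_moe. Illustrative values only:

import torch

num_experts = 8  # example value
topk_ids = torch.tensor([[0, 3, -1], [5, -1, 2]])
local_topk_ids = torch.where(topk_ids == -1, num_experts, topk_ids)
print(local_topk_ids)  # tensor([[0, 3, 8], [5, 8, 2]])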
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -39,7 +39,6 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     configure_logger,
     freeze_gc,
-    get_worker_ids_from_req_rids,
     get_zmq_socket,
     kill_itself_when_parent_died,
 )
@@ -120,39 +119,6 @@ class DetokenizerManager(MultiTokenizerMixin):
         if output is not None:
             self.send_to_tokenizer.send_pyobj(output)

-    def multi_tokenizer_manager_event_loop(self):
-        """The event loop that handles requests, for multi tokenizer manager mode only"""
-        self.create_sockets_mapping()
-        while True:
-            recv_obj = self.recv_from_scheduler.recv_pyobj()
-            output = self._request_dispatcher(recv_obj)
-            if output is None:
-                continue
-            # Extract worker_id from rid
-            if isinstance(recv_obj.rids, list):
-                worker_ids = get_worker_ids_from_req_rids(recv_obj.rids)
-            else:
-                raise RuntimeError(
-                    f"for tokenizer_worker_num > 1, recv_obj.rids must be a list"
-                )
-
-            # Send data using the corresponding socket
-            for i, worker_id in enumerate(worker_ids):
-                if isinstance(recv_obj, MultiTokenizerRegisterReq):
-                    if self.register_tokenizer_ipc(recv_obj, worker_id):
-                        logger.info(
-                            f"DetokenizerManager Created ZMQ socket for worker {worker_id}"
-                        )
-                    continue
-                else:
-                    if worker_id not in self.tokenizer_mapping:
-                        logger.error(
-                            f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
-                        )
-                        continue
-                new_output = self._handle_output_by_index(output, i)
-                self.tokenizer_mapping[worker_id].send_pyobj(new_output)
-
     def trim_matched_stop(
         self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
     ):
sglang/srt/managers/multi_tokenizer_mixin.py
CHANGED
@@ -37,11 +37,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.tokenizer_manager import TokenizerManager, _Communicator
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import (
-    get_worker_ids_from_req_rids,
-    get_zmq_socket,
-    kill_process_tree,
-)
+from sglang.srt.utils import get_zmq_socket, kill_process_tree
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)
@@ -344,6 +340,48 @@ class MultiTokenizerMixin:
             new_output = output
         return new_output

+    def get_worker_ids_from_req_rids(self, rids):
+        if isinstance(rids, list):
+            worker_ids = [int(rid.split("_")[0]) for rid in rids]
+        elif isinstance(rids, str):
+            worker_ids = [int(rids.split("_")[0])]
+        else:
+            worker_ids = []
+        return worker_ids
+
+    def multi_tokenizer_manager_event_loop(self):
+        """The event loop that handles requests, for multi tokenizer manager mode only"""
+        self.create_sockets_mapping()
+        while True:
+            recv_obj = self.recv_from_scheduler.recv_pyobj()
+            output = self._request_dispatcher(recv_obj)
+            if output is None:
+                continue
+            # Extract worker_id from rid
+            if isinstance(recv_obj.rids, list):
+                worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids)
+            else:
+                raise RuntimeError(
+                    f"for tokenizer_worker_num > 1, recv_obj.rids must be a list"
+                )
+
+            # Send data using the corresponding socket
+            for i, worker_id in enumerate(worker_ids):
+                if isinstance(recv_obj, MultiTokenizerRegisterReq):
+                    if self.register_tokenizer_ipc(recv_obj, worker_id):
+                        logger.info(
+                            f"DetokenizerManager Created ZMQ socket for worker {worker_id}"
+                        )
+                    continue
+                else:
+                    if worker_id not in self.tokenizer_mapping:
+                        logger.error(
+                            f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
+                        )
+                        continue
+                new_output = self._handle_output_by_index(output, i)
+                self.tokenizer_mapping[worker_id].send_pyobj(new_output)
+
     def clear_tokenizer_mapping(self):
         if hasattr(self, "tokenizer_mapping"):
             for socket in self.tokenizer_mapping.values():
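get_worker_ids_from_req_rids above assumes request ids are prefixed with the issuing tokenizer worker's id, e.g. "3_<request-uuid>". For illustration:

rids = ["0_9f2c", "2_77ab", "2_c1d0"]
worker_ids = [int(rid.split("_")[0]) for rid in rids]
print(worker_ids)  # [0, 2, 2]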
@@ -406,7 +444,7 @@ class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
             worker_ids = [recv_obj.worker_id]
             recv_obj = recv_obj.obj
         else:
-            worker_ids = get_worker_ids_from_req_rids(recv_obj.rids)
+            worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids)

         if len(worker_ids) == 0:
             logger.error(f"Cannot find worker_id from rids {recv_obj.rids}")
sglang/srt/managers/scheduler.py
CHANGED
@@ -2403,6 +2403,9 @@ class Scheduler(
                 # This only works for requests that have not started anything.
                 # We still need to send something back to TokenizerManager to clean up the state.
                 req = self.waiting_queue.pop(i)
+                if self.enable_hicache_storage:
+                    # to release prefetch events associated with the request
+                    self.tree_cache.release_aborted_request(req.rid)
                 self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
                 # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
                 if self.disaggregation_mode == DisaggregationMode.DECODE:
sglang/srt/mem_cache/hiradix_cache.py
CHANGED
@@ -468,9 +468,9 @@ class HiRadixCache(RadixCache):

         # todo: more policies for prefetch progress such as timeout
         # the current policy is to prefetch with best effort and terminate when queuing is over
-        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
+        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch.pop(
             req_id
-        ]
+        )

         if operation.host_indices is None:
             # prefetch has not been issued due to insufficient host memory
@@ -512,7 +512,6 @@ class HiRadixCache(RadixCache):
             host_indices[min_completed_tokens:completed_tokens]
         )
         last_host_node.release_host()
-        del self.ongoing_prefetch[req_id]
         self.cache_controller.prefetch_tokens_occupied -= len(token_ids)

         return True
@@ -771,3 +770,20 @@ class HiRadixCache(RadixCache):
                 if not cur_child.evicted:
                     stack.append(cur_child)
         return ret_list
+
+    def release_aborted_request(self, rid: str):
+        if rid not in self.ongoing_prefetch:
+            return
+
+        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch.pop(
+            rid
+        )
+        if operation.host_indices is None:
+            return
+
+        completed_tokens, _ = self.cache_controller.terminate_prefetch(operation)
+        if self.tp_world_size > 1:
+            torch.distributed.barrier(group=self.tp_group)
+        last_host_node.release_host()
+        self.cache_controller.append_host_mem_release(host_indices[:completed_tokens])
+        self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
sglang/srt/mem_cache/memory_pool_host.py
CHANGED
@@ -467,6 +467,7 @@ class MHATokenToKVPoolHost(HostKVCache):
         ptr_list = []
         key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
+        indices = indices.tolist()
         v_offset = (
             self.layer_num
             * self.size
@@ -706,6 +707,7 @@ class MLATokenToKVPoolHost(HostKVCache):
         ptr_list = []
         key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
+        indices = indices.tolist()
         for index in range(0, len(indices), self.page_size):
             k_ptr = (
                 kv_buffer_data_ptr
sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py
CHANGED
@@ -4,10 +4,12 @@ import json
 import logging
 import threading
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, OrderedDict, Tuple

+import orjson
 import requests
-from fastapi import FastAPI, HTTPException, Request,
+from fastapi import FastAPI, HTTPException, Request, Response
+from fastapi.responses import ORJSONResponse
 from requests.adapters import HTTPAdapter
 from urllib3.util.retry import Retry

@@ -24,10 +26,10 @@ class RankMetadata:
     """Holds all metadata for a single rank."""

     def __init__(self, num_pages: int):
-        self.lock = threading.
+        self.lock = threading.Lock()
         self.num_pages = num_pages
         self.free_pages: List[int] = list(range(num_pages))
-        self.key_to_index:
+        self.key_to_index: OrderedDict[str, int] = OrderedDict()
         # Todo: Support multi files for HF3FS

     def exists_keys(self, keys: List[str]) -> List[bool]:
@@ -46,16 +48,18 @@ class RankMetadata:
             for i, (key, prefix_key) in enumerate(keys):
                 if key in self.key_to_index:
                     results[i] = (True, self.key_to_index[key])
+                    self.key_to_index.move_to_end(key)
                 else:
                     new_keys_to_process.append((i, key, prefix_key))

             # Todo: Implementing data eviction logic after HiCache supports prefix information pass-through
             for i, key, prefix_key in new_keys_to_process:
                 if len(self.free_pages) > 0:
-
-                    results[i] = (False, page_idx)
+                    page_index = self.free_pages.pop()
                 else:
-
+                    page_index = self.key_to_index.popitem(last=False)[1]
+
+                results[i] = (False, page_index)

             return results

@@ -68,6 +72,7 @@ class RankMetadata:
         with self.lock:
             for key, page_index in written_keys_to_confirm:
                 self.key_to_index[key] = page_index
+                self.key_to_index.move_to_end(key)

             for page_index in pages_to_release:
                 if page_index not in self.free_pages:
@@ -94,7 +99,14 @@ class RankMetadata:
     def get_page_indices(self, keys: List[str]) -> List[Optional[int]]:
         """Get page indices for keys."""
         with self.lock:
-
+            results = []
+            for key in keys:
+                if key in self.key_to_index:
+                    results.append(self.key_to_index[key])
+                    self.key_to_index.move_to_end(key)
+                else:
+                    results.append(None)
+            return results


 class GlobalMetadataState:
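The OrderedDict changes above turn key_to_index into a small LRU structure: lookups call move_to_end so recently used keys stay at the back, and when no free page remains, popitem(last=False) evicts the least recently used key's page. A minimal standalone sketch of that bookkeeping with toy data, not the sglang classes:

from collections import OrderedDict

key_to_index = OrderedDict([("a", 0), ("b", 1), ("c", 2)])
free_pages = []

key_to_index.move_to_end("a")  # "a" was just used, so it becomes most recent
page_index = free_pages.pop() if free_pages else key_to_index.popitem(last=False)[1]
print(page_index, list(key_to_index))  # 1 ['c', 'a']  ("b" was evicted)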
@@ -182,7 +194,8 @@ class Hf3fsMetadataServer:

     def __init__(self, persistence_path: Optional[str] = None, save_interval: int = 60):
         self.state = GlobalMetadataState(persistence_path, save_interval)
-        self.app = FastAPI()
+        self.app = FastAPI(default_response_class=ORJSONResponse)
+
         self._setup_routes()

     def _setup_routes(self):
@@ -199,17 +212,25 @@ class Hf3fsMetadataServer:

     def get_rank_metadata(self, rank: int) -> RankMetadata:
         """Get rank metadata with proper error handling."""
-
-
-
-
-
-
-
+        if rank not in self.state.ranks:
+            raise HTTPException(
+                status_code=404,
+                detail=f"Rank {rank} not initialized. Please call /{rank}/initialize first.",
+            )
+        return self.state.ranks[rank]
+
+    async def _read_json(self, request: Request) -> dict:
+        """Parse request JSON using orjson if available."""
+        body = await request.body()
+        return orjson.loads(body)
+
+    def _json_response(self, content: dict):
+        """Return ORJSONResponse when available to bypass jsonable_encoder."""
+        return ORJSONResponse(content)

     async def initialize(self, rank: int, request: Request):
         """Initialize a rank with specified number of pages."""
-        data = await request.json()
+        data = await self._read_json(request)
         num_pages = data["num_pages"]
         with self.state.global_lock:
             if rank in self.state.ranks:
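The new helpers route all JSON handling through orjson: request bodies are decoded with orjson.loads and responses are returned as ORJSONResponse, skipping FastAPI's jsonable_encoder. A minimal standalone sketch of the same pattern (not the sglang server itself):

import orjson
from fastapi import FastAPI, Request
from fastapi.responses import ORJSONResponse

app = FastAPI(default_response_class=ORJSONResponse)

@app.post("/echo")
async def echo(request: Request):
    data = orjson.loads(await request.body())  # parse the raw bytes with orjson
    return ORJSONResponse({"keys": data.get("keys", [])})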
@@ -223,57 +244,55 @@ class Hf3fsMetadataServer:
             else:
                 logging.info(f"Initializing new Rank {rank} with {num_pages} pages.")
                 self.state.ranks[rank] = RankMetadata(num_pages)
-        return
+        return Response(status_code=204)

     async def exists(self, rank: int, request: Request):
         """Check if keys exist in metadata."""
-        data = await request.json()
+        data = await self._read_json(request)
         keys = data["keys"]
         metadata = self.get_rank_metadata(rank)
         results = metadata.exists_keys(keys)
-        return {"exists": results}
+        return self._json_response({"exists": results})

     async def reserve_and_allocate_page_indices(self, rank: int, request: Request):
         """Reserve and allocate page indices for keys."""
-        data = await request.json()
+        data = await self._read_json(request)
         metadata = self.get_rank_metadata(rank)
         keys = data["keys"]
         results = metadata.reserve_and_allocate_page_indices(keys)
-        return {"indices": results}
+        return self._json_response({"indices": results})

     async def confirm_write(self, rank: int, request: Request):
         """Confirm write operations and release pages."""
-        data = await request.json()
+        data = await self._read_json(request)
         metadata = self.get_rank_metadata(rank)
         success_written_keys = data.get("written_keys_to_confirm", [])
         released_pages = data.get("pages_to_release", [])

         metadata.confirm_write(success_written_keys, released_pages)

-        return {
-            "message": f"Rank {rank}: Write confirmed for {len(success_written_keys)} keys. {len(released_pages)} pages released."
-        }
+        return Response(status_code=204)

     async def delete_keys(self, rank: int, request: Request):
         """Delete keys from metadata."""
-        data = await request.json()
+        data = await self._read_json(request)
         metadata = self.get_rank_metadata(rank)
         count = metadata.delete_keys(data["keys"])
-        return
+        return Response(status_code=204)

     async def clear(self, rank: int):
         """Clear all metadata for a rank."""
         metadata = self.get_rank_metadata(rank)
         metadata.clear_all()
-        return
+        return Response(status_code=204)

     async def get_page_indices(self, rank: int, request: Request):
         """Get page indices for keys."""
-        data = await request.json()
+        data = await self._read_json(request)
         metadata = self.get_rank_metadata(rank)
         keys = data["keys"]
         results = metadata.get_page_indices(keys)
-        return {"indices": results}
+        return self._json_response({"indices": results})

     def run(self, host: str = "0.0.0.0", port: int = 18000):
         """Run the metadata server."""
@@ -309,14 +328,22 @@ class Hf3fsGlobalMetadataClient(Hf3fsMetadataInterface):
             status_forcelist=[500, 502, 503, 504],
             allowed_methods=["GET", "POST"],
         )
-        adapter = HTTPAdapter(
+        adapter = HTTPAdapter(
+            max_retries=retry_strategy, pool_connections=256, pool_maxsize=256
+        )
         self._session.mount("http://", adapter)

     def _post(self, endpoint: str, json_data: dict) -> dict:
         try:
-
+            url = f"{self.base_url}/{endpoint}"
+            headers = {"Content-Type": "application/json"}
+            payload = orjson.dumps(json_data)  # type: ignore[union-attr]
+            response = self._session.post(url, data=payload, headers=headers)
             response.raise_for_status()
-
+
+            if response.status_code == 204 or not response.content:
+                return {}
+            return orjson.loads(response.content)  # type: ignore[union-attr]
         except requests.exceptions.RequestException as e:
             logging.error(f"Failed to POST to {endpoint} after retries: {e}")
             raise RuntimeError(f"Failed to connect to metadata server: {e}") from e
|