sglang 0.5.2rc1__py3-none-any.whl → 0.5.2rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/disaggregation/mini_lb.py +2 -2
- sglang/srt/distributed/parallel_state.py +43 -40
- sglang/srt/entrypoints/http_server.py +5 -1
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +3 -3
- sglang/srt/entrypoints/openai/serving_completions.py +3 -1
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -1
- sglang/srt/entrypoints/openai/serving_responses.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/utils.py +0 -1
- sglang/srt/layers/quantization/modelopt_quant.py +35 -2
- sglang/srt/layers/quantization/mxfp4.py +4 -1
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +0 -18
- sglang/srt/managers/cache_controller.py +42 -39
- sglang/srt/managers/multi_tokenizer_mixin.py +4 -0
- sglang/srt/managers/schedule_policy.py +3 -2
- sglang/srt/managers/scheduler.py +4 -100
- sglang/srt/managers/scheduler_metrics_mixin.py +113 -7
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_manager.py +1 -0
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +15 -10
- sglang/srt/mem_cache/hiradix_cache.py +5 -5
- sglang/srt/mem_cache/memory_pool_host.py +16 -11
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +10 -2
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +32 -13
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/metrics/collector.py +12 -4
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/forward_batch_info.py +16 -17
- sglang/srt/model_executor/model_runner.py +1 -1
- sglang/srt/models/deepseek_v2.py +240 -36
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/qwen2_moe.py +4 -1
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/server_args.py +79 -2
- sglang/srt/speculative/eagle_worker.py +158 -112
- sglang/srt/utils.py +12 -0
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc1.dist-info → sglang-0.5.2rc2.dist-info}/METADATA +1 -1
- {sglang-0.5.2rc1.dist-info → sglang-0.5.2rc2.dist-info}/RECORD +65 -61
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.2rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.2rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.2rc2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/multi_tokenizer_mixin.py
CHANGED
@@ -23,6 +23,7 @@ import threading
 from multiprocessing import shared_memory
 from typing import Dict

+import setproctitle
 import zmq
 import zmq.asyncio

@@ -476,6 +477,9 @@ class MultiTokenizerManager(TokenizerManager, MultiTokenizerMixin):
         server_args: ServerArgs,
         port_args: PortArgs,
     ):
+        setproctitle.setproctitle(
+            f"sglang::http_server/multi_tokenizer_manager:{os.getpid()}"
+        )
         # prevent init prefill bootstrapserver again
         disaggregation_mode = server_args.disaggregation_mode
         server_args.disaggregation_mode = "null"
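The only change here is the new setproctitle call, which stamps each multi-tokenizer worker with a recognizable process title. A minimal standalone sketch of the same pattern (nothing beyond the title string above is taken from the diff), assuming the setproctitle package is installed:

    import os

    import setproctitle

    # After this call, tools such as `ps -ef` or `top` display the custom title
    # instead of a generic "python" command line, so tokenizer workers are easy
    # to tell apart from scheduler or model-runner processes.
    setproctitle.setproctitle(
        f"sglang::http_server/multi_tokenizer_manager:{os.getpid()}"
    )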
sglang/srt/managers/schedule_policy.py
CHANGED
@@ -380,8 +380,9 @@ class PrefillAdder:
         self.log_input_tokens += extend_input_len

     def add_chunked_req(self, req: Req):
-
-
+        _rem_tokens = min(self.rem_chunk_tokens, int(self.rem_total_tokens))
+        truncated = req.extend_input_len > _rem_tokens
+        req.extend_input_len = min(req.extend_input_len, _rem_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
         self.can_run_list.append(req)
         self._update_prefill_budget(
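With this change, a chunked prefill request is capped by the remaining total token budget as well as by the per-chunk budget. A small self-contained sketch of that capping rule; the helper function and the numbers are illustrative, not sglang APIs:

    def cap_chunked_extend_len(
        extend_input_len: int, rem_chunk_tokens: int, rem_total_tokens: float
    ):
        # Take the tighter of the chunk budget and the remaining total budget.
        _rem_tokens = min(rem_chunk_tokens, int(rem_total_tokens))
        truncated = extend_input_len > _rem_tokens
        return min(extend_input_len, _rem_tokens), truncated

    # 8192 prompt tokens left, a 4096-token chunk budget, but only 1000 total
    # tokens remaining -> the chunk is capped at 1000 and marked as truncated.
    print(cap_chunked_extend_len(8192, 4096, 1000.0))  # (1000, True)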
sglang/srt/managers/scheduler.py
CHANGED
@@ -141,7 +141,7 @@ from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
-from sglang.srt.reasoning_parser import ReasoningParser
+from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -500,6 +500,7 @@ class Scheduler(
         # Init metrics stats
         self.init_metrics(tp_rank, pp_rank, dp_rank)
         self.init_kv_events(server_args.kv_events_config)
+        self.init_dp_balance(dp_balance_meta)

         # Init disaggregation
         self.disaggregation_mode = DisaggregationMode(
@@ -545,15 +546,6 @@ class Scheduler(
             ]
         )

-        self.balance_meta = dp_balance_meta
-        if (
-            server_args.enable_dp_attention
-            and server_args.load_balance_method == "minimum_tokens"
-        ):
-            assert dp_balance_meta is not None
-
-            self.recv_dp_balance_id_this_term = []
-
     def init_tokenizer(self):
         server_args = self.server_args
         self.is_generation = self.model_config.is_generation
@@ -1126,11 +1118,7 @@ class Scheduler(
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
-        if (
-            self.server_args.enable_dp_attention
-            and self.server_args.load_balance_method == "minimum_tokens"
-        ):
-            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
+        self.maybe_update_dp_balance_data(recv_req)

         # Create a new request
         if (
@@ -1568,11 +1556,7 @@ class Scheduler(

         # Handle DP attention
         if need_dp_attn_preparation:
-            if (
-                self.server_args.load_balance_method == "minimum_tokens"
-                and self.forward_ct % 40 == 0
-            ):
-                self.handle_dp_balance_data(ret)
+            self.maybe_handle_dp_balance_data()
             ret = self.prepare_mlp_sync_batch(ret)

         return ret
@@ -1897,86 +1881,6 @@ class Scheduler(
             disable_overlap_schedule=self.server_args.disable_overlap_schedule,
         )

-    def handle_dp_balance_data(self, local_batch: ScheduleBatch):
-        def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
-            """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
-            recv_list = self.recv_dp_balance_id_this_term
-            assert len(recv_list) <= 511, (
-                "The number of requests received this round is too large. "
-                "Please increase gather_tensor_size and onfly_info_size."
-            )
-            # The maximum size of the tensor used for gathering data from all workers.
-            gather_tensor_size = 512
-
-            # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
-            recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
-            recv_tensor[0] = holding_tokens_list
-            recv_tensor[1] = len(
-                recv_list
-            )  # The first element is the length of the list.
-            recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
-                recv_list, dtype=torch.int32
-            )
-
-            if self.tp_rank == 0:
-                gathered_list = [
-                    torch.zeros(gather_tensor_size, dtype=torch.int32)
-                    for _ in range(self.balance_meta.num_workers)
-                ]
-            else:
-                gathered_list = None
-
-            torch.distributed.gather(
-                recv_tensor, gathered_list, group=self.tp_cpu_group
-            )
-
-            gathered_id_list_per_worker = None
-            if self.tp_rank == 0:
-                gathered_id_list_per_worker = []
-                holding_tokens_list = []
-                for tensor in gathered_list:
-                    holding_tokens_list.append(tensor[0].item())
-                    list_length = tensor[1].item()
-                    gathered_id_list_per_worker.append(
-                        tensor[2 : list_length + 2].tolist()
-                    )
-
-            return gathered_id_list_per_worker, holding_tokens_list
-
-        def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
-            meta = self.balance_meta
-
-            with meta.mutex:
-                onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
-                assert len(new_recv_rid_lists) == len(
-                    onfly_list
-                ), "num_worker not equal"
-                # 1.Check if the rid received by each worker this round is present in onfly.
-                # If it is, remove the corresponding onfly item.
-                worker_id = 0
-                for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
-                    for new_recv_rid in new_recv_rids:
-                        assert (
-                            new_recv_rid in on_fly_reqs
-                        ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
-                        del on_fly_reqs[new_recv_rid]
-                    worker_id += 1
-                # 2. Atomically write local_tokens and onfly into shm under the mutex
-                meta.set_shared_onfly_info(onfly_list)
-                meta.set_shared_local_tokens(local_tokens)
-
-        holding_tokens = self.get_load()
-
-        new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
-            holding_tokens
-        )
-
-        self.recv_dp_balance_id_this_term.clear()
-        if self.tp_rank == 0:  # only first worker write info
-            write_shared_dp_balance_info(
-                new_recv_dp_balance_id_list, holding_token_list
-            )
-
     @staticmethod
     def prepare_mlp_sync_batch_raw(
         local_batch: ScheduleBatch,
sglang/srt/managers/scheduler_metrics_mixin.py
CHANGED
@@ -1,15 +1,24 @@
+from __future__ import annotations
+
 import logging
 import time
 from collections import defaultdict
-from typing import List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
+
+import torch

 from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
 from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
 from sglang.srt.managers.schedule_policy import PrefillAdder
 from sglang.srt.managers.scheduler import Req, ScheduleBatch
+from sglang.srt.managers.utils import DPBalanceMeta
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
 from sglang.srt.utils import get_bool_env_var

+if TYPE_CHECKING:
+    from sglang.srt.managers.scheduler import Scheduler
+
 logger = logging.getLogger(__name__)

 RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
@@ -28,7 +37,9 @@ class KvMetrics:


 class SchedulerMetricsMixin:
-    def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
+    def init_metrics(
+        self: Scheduler, tp_rank: int, pp_rank: int, dp_rank: Optional[int]
+    ):
         self.last_gen_throughput: float = 0.0
         self.last_input_throughput: float = 0.0
         self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
@@ -50,14 +61,24 @@ class SchedulerMetricsMixin:
             labels["dp_rank"] = dp_rank
         self.metrics_collector = SchedulerMetricsCollector(labels=labels)

-    def init_kv_events(self, kv_events_config: Optional[str]):
+    def init_dp_balance(self: Scheduler, dp_balance_meta: Optional[DPBalanceMeta]):
+        self.balance_meta = dp_balance_meta
+        if (
+            self.server_args.enable_dp_attention
+            and self.server_args.load_balance_method == "minimum_tokens"
+        ):
+            assert dp_balance_meta is not None
+
+            self.recv_dp_balance_id_this_term = []
+
+    def init_kv_events(self: Scheduler, kv_events_config: Optional[str]):
         if self.enable_kv_cache_events:
             self.kv_event_publisher = EventPublisherFactory.create(
                 kv_events_config, self.attn_dp_rank
             )

     def log_prefill_stats(
-        self,
+        self: Scheduler,
         adder: PrefillAdder,
         can_run_list: List[Req],
         running_bs: int,
@@ -138,7 +159,7 @@ class SchedulerMetricsMixin:
         self._publish_kv_events()

     def log_decode_stats(
-        self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
+        self: Scheduler, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
     ):
         batch = running_batch or self.running_batch

@@ -220,7 +241,7 @@
         self._emit_kv_metrics()
         self._publish_kv_events()

-    def _emit_kv_metrics(self):
+    def _emit_kv_metrics(self: Scheduler):
         kv_metrics = KvMetrics()
         kv_metrics.request_active_slots = self.stats.num_running_reqs
         kv_metrics.request_total_slots = self.max_running_requests
@@ -236,9 +257,94 @@
         if not self.send_metrics_from_scheduler.closed:
             self.send_metrics_from_scheduler.send_pyobj(kv_metrics)

-    def _publish_kv_events(self):
+    def _publish_kv_events(self: Scheduler):
         if self.enable_kv_cache_events:
             events = self.tree_cache.take_events()
             if events:
                 batch = KVEventBatch(ts=time.time(), events=events)
                 self.kv_event_publisher.publish(batch)
+
+    def maybe_update_dp_balance_data(
+        self: Scheduler, recv_req: TokenizedGenerateReqInput
+    ):
+        if (
+            self.server_args.enable_dp_attention
+            and self.server_args.load_balance_method == "minimum_tokens"
+        ):
+            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
+
+    def maybe_handle_dp_balance_data(self: Scheduler):
+        if (
+            self.server_args.load_balance_method == "minimum_tokens"
+            and self.forward_ct % 40 == 0
+        ):
+            holding_tokens = self.get_load()
+
+            new_recv_dp_balance_id_list, holding_token_list = (
+                self.gather_dp_balance_info(holding_tokens)
+            )
+
+            self.recv_dp_balance_id_this_term.clear()
+            if self.tp_rank == 0:  # only first worker write info
+                self.write_shared_dp_balance_info(
+                    new_recv_dp_balance_id_list, holding_token_list
+                )
+
+    def gather_dp_balance_info(
+        self: Scheduler, holding_tokens_list
+    ) -> Union[None, List[List[int]]]:
+        """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
+        recv_list = self.recv_dp_balance_id_this_term
+        assert len(recv_list) <= 511, (
+            "The number of requests received this round is too large. "
+            "Please increase gather_tensor_size and onfly_info_size."
+        )
+        # The maximum size of the tensor used for gathering data from all workers.
+        gather_tensor_size = 512
+
+        # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
+        recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
+        recv_tensor[0] = holding_tokens_list
+        recv_tensor[1] = len(recv_list)  # The first element is the length of the list.
+        recv_tensor[2 : len(recv_list) + 2] = torch.tensor(recv_list, dtype=torch.int32)
+
+        if self.tp_rank == 0:
+            gathered_list = [
+                torch.zeros(gather_tensor_size, dtype=torch.int32)
+                for _ in range(self.balance_meta.num_workers)
+            ]
+        else:
+            gathered_list = None
+
+        torch.distributed.gather(recv_tensor, gathered_list, group=self.tp_cpu_group)
+
+        gathered_id_list_per_worker = None
+        if self.tp_rank == 0:
+            gathered_id_list_per_worker = []
+            holding_tokens_list = []
+            for tensor in gathered_list:
+                holding_tokens_list.append(tensor[0].item())
+                list_length = tensor[1].item()
+                gathered_id_list_per_worker.append(tensor[2 : list_length + 2].tolist())
+
+        return gathered_id_list_per_worker, holding_tokens_list
+
+    def write_shared_dp_balance_info(self: Scheduler, new_recv_rid_lists, local_tokens):
+        meta = self.balance_meta
+
+        with meta.mutex:
+            onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
+            assert len(new_recv_rid_lists) == len(onfly_list), "num_worker not equal"
+            # 1.Check if the rid received by each worker this round is present in onfly.
+            # If it is, remove the corresponding onfly item.
+            worker_id = 0
+            for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
+                for new_recv_rid in new_recv_rids:
+                    assert (
+                        new_recv_rid in on_fly_reqs
+                    ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
+                    del on_fly_reqs[new_recv_rid]
+                worker_id += 1
+            # 2. Atomically write local_tokens and onfly into shm under the mutex
+            meta.set_shared_onfly_info(onfly_list)
+            meta.set_shared_local_tokens(local_tokens)
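The relocated gather_dp_balance_info above packs each worker's holding-token count and received request ids into one fixed-size int32 tensor before a torch.distributed.gather. A standalone sketch of that packing layout, with hypothetical helper names and without the distributed call:

    import torch

    def pack_balance_info(holding_tokens, recv_ids, size=512):
        # Layout: | holding_tokens | len(recv_ids) | recv_ids ... | zero padding |
        assert len(recv_ids) <= size - 2
        packed = torch.zeros(size, dtype=torch.int32)
        packed[0] = holding_tokens
        packed[1] = len(recv_ids)
        packed[2 : 2 + len(recv_ids)] = torch.tensor(recv_ids, dtype=torch.int32)
        return packed

    def unpack_balance_info(packed):
        num_ids = int(packed[1].item())
        return int(packed[0].item()), packed[2 : 2 + num_ids].tolist()

    print(unpack_balance_info(pack_balance_info(1234, [7, 8, 9])))  # (1234, [7, 8, 9])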
sglang/srt/managers/template_manager.py
CHANGED
@@ -24,20 +24,20 @@ import os
 import re
 from typing import Optional

-from sglang.srt.code_completion_parser import (
+from sglang.srt.parser.code_completion_parser import (
     CompletionTemplate,
     FimPosition,
     completion_template_exists,
     register_completion_template,
 )
-from sglang.srt.conversation import (
+from sglang.srt.parser.conversation import (
     Conversation,
     SeparatorStyle,
     chat_template_exists,
     get_conv_template_by_model_path,
     register_conv_template,
 )
-from sglang.srt.jinja_template_utils import detect_jinja_template_content_format
+from sglang.srt.parser.jinja_template_utils import detect_jinja_template_content_format

 logger = logging.getLogger(__name__)

sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -329,6 +329,7 @@ class TokenizerManager:
         # Metrics
         if self.enable_metrics:
             self.metrics_collector = TokenizerMetricsCollector(
+                server_args=server_args,
                 labels={
                     "model_name": self.server_args.served_model_name,
                     # TODO: Add lora name/path in the future,
sglang/srt/mem_cache/allocator.py
CHANGED
@@ -283,7 +283,7 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         self.swa_attn_allocator.clear()
         self.full_attn_allocator.clear()
         self.full_to_swa_index_mapping.fill_(0)
-        self.
+        self.is_not_in_free_group = True
         self.free_group = []


sglang/srt/mem_cache/hicache_storage.py
CHANGED
@@ -27,6 +27,7 @@ class HiCacheStorageConfig:
     tp_rank: int
     tp_size: int
     is_mla_model: bool
+    is_page_first_layout: bool
     model_name: Optional[str]
     extra_config: Optional[dict] = None

@@ -135,18 +136,24 @@ class HiCacheFile(HiCacheStorage):
     ):
         self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)

-        tp_rank, tp_size,
+        tp_rank, tp_size, model_name, is_mla_model = (
             storage_config.tp_rank,
             storage_config.tp_size,
+            storage_config.model_name,
             storage_config.is_mla_model,
         )
-
+        model_name = "-".join(model_name.split("/")) if model_name else ""
+        if is_mla_model:
+            self.config_suffix = f"_{model_name}"
+        else:
+            self.config_suffix = f"_{model_name}_{tp_rank}_{tp_size}"
+
         if not os.path.exists(self.file_path) and tp_rank == 0:
             os.makedirs(self.file_path)
             logger.info(f"Created HiCacheFile storage directory at {self.file_path}")

     def _get_suffixed_key(self, key: str) -> str:
-        return key + self.
+        return key + self.config_suffix

     def get(
         self,
@@ -157,13 +164,11 @@ class HiCacheFile(HiCacheStorage):
         key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
-
-            with open(tensor_path, "rb") as f:
-                target_location.
-
-
-                .untyped_storage()
-            )
+            expected = target_location.numel() * target_location.element_size()
+            with open(tensor_path, "rb", buffering=0) as f:
+                buf = memoryview(target_location.view(torch.uint8).contiguous().numpy())
+                if f.readinto(buf) != expected:
+                    raise IOError(f"Short read for {key}")
             return target_location
         except FileNotFoundError:
             logger.warning(f"Failed to fetch {key} from HiCacheFile storage.")
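The rewritten HiCacheFile.get above reads the file straight into the destination tensor's memory and checks for short reads. A standalone sketch of that pattern for a contiguous CPU tensor; the helper name and file path are illustrative, not sglang APIs:

    import torch

    def load_into(path: str, target: torch.Tensor) -> torch.Tensor:
        # target must be a contiguous CPU tensor; its memory is filled in place.
        expected = target.numel() * target.element_size()
        buf = memoryview(target.view(torch.uint8).numpy())
        with open(path, "rb", buffering=0) as f:
            if f.readinto(buf) != expected:
                raise IOError(f"Short read for {path}")
        return target

    # Round-trip demo with a hypothetical file path.
    torch.ones(4, dtype=torch.float32).numpy().tofile("/tmp/demo_page.bin")
    print(load_into("/tmp/demo_page.bin", torch.zeros(4, dtype=torch.float32)))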
sglang/srt/mem_cache/hiradix_cache.py
CHANGED
@@ -468,9 +468,9 @@ class HiRadixCache(RadixCache):

         # todo: more policies for prefetch progress such as timeout
         # the current policy is to prefetch with best effort and terminate when queuing is over
-        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch
+        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
             req_id
-
+        ]

         if operation.host_indices is None:
             # prefetch has not been issued due to insufficient host memory
@@ -512,6 +512,7 @@ class HiRadixCache(RadixCache):
             host_indices[min_completed_tokens:completed_tokens]
         )
         last_host_node.release_host()
+        del self.ongoing_prefetch[req_id]
         self.cache_controller.prefetch_tokens_occupied -= len(token_ids)

         return True
@@ -775,9 +776,7 @@ class HiRadixCache(RadixCache):
         if rid not in self.ongoing_prefetch:
             return

-        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch
-            rid
-        )
+        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[rid]
         if operation.host_indices is None:
             return

@@ -785,5 +784,6 @@ class HiRadixCache(RadixCache):
         if self.tp_world_size > 1:
             torch.distributed.barrier(group=self.tp_group)
         last_host_node.release_host()
+        del self.ongoing_prefetch[rid]
         self.cache_controller.append_host_mem_release(host_indices[:completed_tokens])
         self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
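These hunks switch ongoing_prefetch to per-request bookkeeping: entries are looked up by request id and deleted once the prefetch completes or is terminated. A toy sketch of that dict-based lifecycle, with simplified stand-in values rather than the real node and operation objects:

    ongoing_prefetch = {}

    # One entry per in-flight prefetch, keyed by request id; the tuple members
    # stand in for (last_host_node, token_ids, host_indices, operation).
    ongoing_prefetch["req-1"] = ("node", [101, 102, 103], [0, 1, 2], "operation")

    def finish_prefetch(rid):
        if rid not in ongoing_prefetch:
            return
        last_host_node, token_ids, host_indices, operation = ongoing_prefetch[rid]
        # ... release host memory associated with host_indices here ...
        del ongoing_prefetch[rid]  # drop the entry so it cannot be consumed twice

    finish_prefetch("req-1")
    print(ongoing_prefetch)  # {}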
sglang/srt/mem_cache/memory_pool_host.py
CHANGED
@@ -500,20 +500,23 @@ class MHATokenToKVPoolHost(HostKVCache):
         element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list

-    def get_buffer_with_hash(self, keys, indices):
+    def get_buffer_with_hash(self, keys, indices=None):
         assert self.layout == "page_first"
-        assert len(keys) == (len(indices) // self.page_size)
+        assert indices is None or (len(keys) == (len(indices) // self.page_size))

         key_list = []
         buf_list = []

-        for
+        for i in range(len(keys)):
+            key = keys[i]
             key_list.append(f"{key}-k")
-            buf_list.append(self.k_buffer[i : i + self.page_size])
             key_list.append(f"{key}-v")
-
+            if indices is not None:
+                index = indices[i * self.page_size]
+                buf_list.append(self.k_buffer[index : index + self.page_size])
+                buf_list.append(self.v_buffer[index : index + self.page_size])

-        return key_list, buf_list
+        return key_list, buf_list, 2


 class MLATokenToKVPoolHost(HostKVCache):
@@ -728,13 +731,15 @@ class MLATokenToKVPoolHost(HostKVCache):
         element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list

-    def get_buffer_with_hash(self, keys, indices):
+    def get_buffer_with_hash(self, keys, indices=None):
         assert self.layout == "page_first"
-        assert len(keys) == (len(indices) // self.page_size)
+        assert indices is None or (len(keys) == (len(indices) // self.page_size))

         buf_list = []

-
-
+        if indices is not None:
+            for i in range(len(keys)):
+                index = indices[i * self.page_size]
+                buf_list.append(self.kv_buffer[index : index + self.page_size])

-        return keys, buf_list
+        return keys, buf_list, 1
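get_buffer_with_hash now accepts indices=None and returns a third value giving the number of host buffers per key (2 for the MHA pool's k/v pair, 1 for the MLA pool). A toy stand-in illustrating that return convention; the class below is illustrative, not sglang code:

    class ToyMHAHostPool:
        page_size = 1

        def __init__(self):
            self.k_buffer = list(range(100))
            self.v_buffer = list(range(100, 200))

        def get_buffer_with_hash(self, keys, indices=None):
            key_list, buf_list = [], []
            for i, key in enumerate(keys):
                key_list += [f"{key}-k", f"{key}-v"]
                if indices is not None:
                    index = indices[i * self.page_size]
                    buf_list.append(self.k_buffer[index : index + self.page_size])
                    buf_list.append(self.v_buffer[index : index + self.page_size])
            return key_list, buf_list, 2  # two host buffers (k and v) per page hash

    pool = ToyMHAHostPool()
    print(pool.get_buffer_with_hash(["h0"]))               # (['h0-k', 'h0-v'], [], 2)
    print(pool.get_buffer_with_hash(["h0"], indices=[3]))  # adds the page-3 k/v slices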
sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
CHANGED
@@ -128,6 +128,7 @@ class HiCacheHF3FS(HiCacheStorage):
         dtype: torch.dtype,
         metadata_client: Hf3fsMetadataInterface,
         is_mla_model: bool = False,
+        is_page_first_layout: bool = False,
     ):
         self.rank = rank
         self.file_path = file_path
@@ -138,6 +139,7 @@ class HiCacheHF3FS(HiCacheStorage):
         self.dtype = dtype
         self.metadata_client = metadata_client
         self.is_mla_model = is_mla_model
+        self.is_page_first_layout = is_page_first_layout
         self.numel = self.bytes_per_page // self.dtype.itemsize
         self.num_pages = self.file_size // self.bytes_per_page
         self.skip_backup = False
@@ -193,9 +195,13 @@ class HiCacheHF3FS(HiCacheStorage):
         )

         if storage_config is not None:
-            rank, is_mla_model =
+            rank, is_mla_model, is_page_first_layout = (
+                storage_config.tp_rank,
+                storage_config.is_mla_model,
+                storage_config.is_page_first_layout,
+            )
         else:
-            rank, is_mla_model = 0, False
+            rank, is_mla_model, is_page_first_layout = 0, False, False

         mla_unsupported_msg = f"MLA model is not supported without global metadata server, please refer to https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md"

@@ -213,6 +219,7 @@ class HiCacheHF3FS(HiCacheStorage):
                 entries=8,
                 dtype=dtype,
                 metadata_client=Hf3fsLocalMetadataClient(),
+                is_page_first_layout=is_page_first_layout,
             )

         try:
@@ -261,6 +268,7 @@ class HiCacheHF3FS(HiCacheStorage):
             dtype=dtype,
             metadata_client=metadata_client,
             is_mla_model=is_mla_model,
+            is_page_first_layout=is_page_first_layout,
         )

     def get(