sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post6__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/srt/configs/model_config.py +15 -6
- sglang/srt/layers/attention/flashinfer_backend.py +17 -3
- sglang/srt/layers/linear.py +36 -98
- sglang/srt/layers/moe/fused_moe_triton/layer.py +37 -9
- sglang/srt/layers/moe/topk.py +4 -2
- sglang/srt/layers/parameter.py +24 -16
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/fp8.py +106 -52
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/int8_kernel.py +54 -0
- sglang/srt/layers/quantization/modelopt_quant.py +1 -1
- sglang/srt/layers/quantization/w8a8_int8.py +117 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/vocab_parallel_embedding.py +15 -2
- sglang/srt/managers/configure_logging.py +43 -0
- sglang/srt/managers/detokenizer_manager.py +0 -2
- sglang/srt/managers/io_struct.py +29 -13
- sglang/srt/managers/scheduler.py +48 -9
- sglang/srt/managers/tokenizer_manager.py +109 -49
- sglang/srt/mem_cache/memory_pool.py +107 -52
- sglang/srt/metrics/collector.py +10 -5
- sglang/srt/model_executor/model_runner.py +43 -6
- sglang/srt/models/llama.py +37 -2
- sglang/srt/models/qwen2.py +11 -0
- sglang/srt/models/qwen2_eagle.py +131 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
- sglang/srt/sampling/sampling_batch_info.py +14 -5
- sglang/srt/sampling/sampling_params.py +1 -1
- sglang/srt/server.py +114 -61
- sglang/srt/server_args.py +27 -18
- sglang/srt/speculative/eagle_worker.py +1 -0
- sglang/srt/torch_memory_saver_adapter.py +59 -0
- sglang/srt/utils.py +29 -0
- sglang/version.py +1 -1
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/METADATA +12 -10
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/RECORD +39 -34
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
```diff
@@ -13,6 +13,7 @@
 # ==============================================================================
 """A scheduler that manages a tensor parallel GPU worker."""
 
+import faulthandler
 import logging
 import os
 import signal
@@ -46,6 +47,10 @@ from sglang.srt.managers.io_struct import (
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
+    ReleaseMemoryOccupationReqInput,
+    ReleaseMemoryOccupationReqOutput,
+    ResumeMemoryOccupationReqInput,
+    ResumeMemoryOccupationReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
@@ -77,6 +82,7 @@ from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerSta
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
@@ -356,6 +362,10 @@ class Scheduler:
             t.start()
         self.parent_process = psutil.Process().parent()
 
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=server_args.enable_memory_saver
+        )
+
         # Init profiler
         if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
             self.profiler = None
@@ -399,6 +409,8 @@ class Scheduler:
             self.watchdog_last_time = time.time()
             time.sleep(self.watchdog_timeout / 2)
 
+        # Wait sometimes so that the parent process can print the error.
+        time.sleep(5)
         self.parent_process.send_signal(signal.SIGQUIT)
 
     @torch.no_grad()
@@ -516,6 +528,12 @@ class Scheduler:
         elif isinstance(recv_req, GetWeightsByNameReqInput):
             parameter = self.get_weights_by_name(recv_req)
             self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
+        elif isinstance(recv_req, ReleaseMemoryOccupationReqInput):
+            self.release_memory_occupation()
+            self.send_to_tokenizer.send_pyobj(ReleaseMemoryOccupationReqOutput())
+        elif isinstance(recv_req, ResumeMemoryOccupationReqInput):
+            self.resume_memory_occupation()
+            self.send_to_tokenizer.send_pyobj(ResumeMemoryOccupationReqOutput())
         elif isinstance(recv_req, ProfileReq):
             if recv_req == ProfileReq.START_PROFILE:
                 self.start_profile()
@@ -1253,7 +1271,6 @@ class Scheduler:
         decode_ids_list = []
         read_offsets = []
         output_ids = []
-        origin_input_ids = []
 
         skip_special_tokens = []
         spaces_between_special_tokens = []
@@ -1305,14 +1322,8 @@ class Scheduler:
             decode_ids, read_offset = req.init_incremental_detokenize()
             decode_ids_list.append(decode_ids)
             read_offsets.append(read_offset)
-            if self.skip_tokenizer_init
+            if self.skip_tokenizer_init:
                 output_ids.append(req.output_ids)
-            else:
-                output_ids = None
-            if self.server_args.return_token_ids:
-                origin_input_ids.append(req.origin_input_ids)
-            else:
-                origin_input_ids = None
             skip_special_tokens.append(req.sampling_params.skip_special_tokens)
             spaces_between_special_tokens.append(
                 req.sampling_params.spaces_between_special_tokens
@@ -1344,7 +1355,6 @@ class Scheduler:
                 decoded_texts,
                 decode_ids_list,
                 read_offsets,
-                origin_input_ids,
                 output_ids,
                 skip_special_tokens,
                 spaces_between_special_tokens,
@@ -1543,6 +1553,20 @@ class Scheduler:
         parameter = self.tp_worker.get_weights_by_name(recv_req)
         return parameter
 
+    def release_memory_occupation(self):
+        self.stashed_model_static_state = _export_static_state(
+            self.tp_worker.worker.model_runner.model
+        )
+        self.memory_saver_adapter.pause()
+        self.flush_cache()
+
+    def resume_memory_occupation(self):
+        self.memory_saver_adapter.resume()
+        _import_static_state(
+            self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
+        )
+        del self.stashed_model_static_state
+
     def start_profile(self) -> None:
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")
@@ -1581,6 +1605,20 @@ class Scheduler:
             del self.sessions[session_id]
 
 
+def _export_static_state(model):
+    return dict(
+        buffers=[
+            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
+        ]
+    )
+
+
+def _import_static_state(model, static_params):
+    self_named_buffers = dict(model.named_buffers())
+    for name, tensor in static_params["buffers"]:
+        self_named_buffers[name][...] = tensor
+
+
 def run_scheduler_process(
     server_args: ServerArgs,
     port_args: PortArgs,
@@ -1590,6 +1628,7 @@ def run_scheduler_process(
     pipe_writer,
 ):
     setproctitle.setproctitle("sglang::scheduler")
+    faulthandler.enable()
 
     # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
     if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
```
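The release/resume pair added above lets a client temporarily give back the scheduler's GPU memory and later restore it: `release_memory_occupation` snapshots the model's registered buffers, pauses the `TorchMemorySaverAdapter`, and flushes the cache, while `resume_memory_occupation` resumes the allocator and copies the snapshot back. A minimal standalone sketch of what the new `_export_static_state` / `_import_static_state` helpers do, demonstrated on a toy `torch.nn` module rather than SGLang's model runner:

```python
import torch
from torch import nn


def export_static_state(model: nn.Module) -> dict:
    # Mirrors _export_static_state: snapshot every named buffer so it can be
    # restored after the occupied memory is released.
    return dict(
        buffers=[(name, buf.detach().clone()) for name, buf in model.named_buffers()]
    )


def import_static_state(model: nn.Module, static_params: dict) -> None:
    # Mirrors _import_static_state: copy the stashed values back in place.
    named_buffers = dict(model.named_buffers())
    for name, tensor in static_params["buffers"]:
        named_buffers[name][...] = tensor


if __name__ == "__main__":
    m = nn.BatchNorm1d(4)            # has running_mean / running_var buffers
    m.running_mean.fill_(1.0)
    stash = export_static_state(m)

    m.running_mean.zero_()           # simulate the buffer contents being lost
    import_static_state(m, stash)
    assert torch.allclose(m.running_mean, torch.ones(4))
```

Only buffers are stashed here; the weight tensors themselves are managed by the memory-saver adapter's pause/resume cycle.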
sglang/srt/managers/tokenizer_manager.py
CHANGED
```diff
@@ -18,10 +18,12 @@ import copy
 import dataclasses
 import logging
 import os
+import pickle
 import signal
 import sys
 import time
 import uuid
+from datetime import datetime
 from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
 
 import fastapi
@@ -43,6 +45,7 @@ from sglang.srt.managers.io_struct import (
     BatchStrOut,
     BatchTokenIDOut,
     CloseSessionReqInput,
+    ConfigureLoggingReq,
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
@@ -53,6 +56,10 @@ from sglang.srt.managers.io_struct import (
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
+    ReleaseMemoryOccupationReqInput,
+    ReleaseMemoryOccupationReqOutput,
+    ResumeMemoryOccupationReqInput,
+    ResumeMemoryOccupationReqOutput,
     SessionParams,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -105,6 +112,7 @@ class TokenizerManager:
         # Parse args
         self.server_args = server_args
         self.enable_metrics = server_args.enable_metrics
+        self.log_requests = server_args.log_requests
 
         # Init inter-process communication
         context = zmq.asyncio.Context(2)
@@ -163,6 +171,9 @@ class TokenizerManager:
         # Store states
         self.to_create_loop = True
         self.rid_to_state: Dict[str, ReqState] = {}
+        self.dump_requests_folder = ""  # By default do not dump
+        self.dump_requests_threshold = 1000
+        self.dump_request_list: List[Tuple] = []
 
         # The event to notify the weight sync is finished.
         self.model_update_lock = RWLock()
@@ -188,6 +199,12 @@ class TokenizerManager:
         self.get_weights_by_name_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.release_memory_occupation_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
+        self.resume_memory_occupation_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
 
         # Metrics
         if self.enable_metrics:
@@ -215,7 +232,7 @@ class TokenizerManager:
 
         obj.normalize_batch_and_arguments()
 
-        if self.
+        if self.log_requests:
             logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}")
 
         async with self.model_update_lock.reader_lock:
@@ -336,7 +353,7 @@ class TokenizerManager:
 
             state.out_list = []
             if state.finished:
-                if self.
+                if self.log_requests:
                     msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
                     logger.info(msg)
                 del self.rid_to_state[obj.rid]
@@ -548,6 +565,22 @@ class TokenizerManager:
         else:
             return all_parameters
 
+    async def release_memory_occupation(
+        self,
+        obj: ReleaseMemoryOccupationReqInput,
+        request: Optional[fastapi.Request] = None,
+    ):
+        self.auto_create_handle_loop()
+        await self.release_memory_occupation_communicator(obj)
+
+    async def resume_memory_occupation(
+        self,
+        obj: ResumeMemoryOccupationReqInput,
+        request: Optional[fastapi.Request] = None,
+    ):
+        self.auto_create_handle_loop()
+        await self.resume_memory_occupation_communicator(obj)
+
     async def open_session(
         self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
     ):
@@ -571,6 +604,15 @@ class TokenizerManager:
         assert not self.to_create_loop, "close session should not be the first request"
         await self.send_to_scheduler.send_pyobj(obj)
 
+    def configure_logging(self, obj: ConfigureLoggingReq):
+        if obj.log_requests is not None:
+            self.log_requests = obj.log_requests
+        if obj.dump_requests_folder is not None:
+            self.dump_requests_folder = obj.dump_requests_folder
+        if obj.dump_requests_threshold is not None:
+            self.dump_requests_threshold = obj.dump_requests_threshold
+        logging.info(f"Config logging: {obj=}")
+
     def create_abort_task(self, obj: GenerateReqInput):
         # Abort the request if the client is disconnected.
         async def abort_request():
@@ -601,7 +643,7 @@ class TokenizerManager:
         while not self.gracefully_exit:
             await asyncio.sleep(5)
 
-        #
+        # Drain requests
         while True:
             remain_num_req = len(self.rid_to_state)
             logger.info(
@@ -627,6 +669,8 @@ class TokenizerManager:
                 UpdateWeightsFromDistributedReqOutput,
                 GetWeightsByNameReqOutput,
                 InitWeightsUpdateGroupReqOutput,
+                ReleaseMemoryOccupationReqOutput,
+                ResumeMemoryOccupationReqOutput,
             ] = await self.recv_from_detokenizer.recv_pyobj()
 
             if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
@@ -663,13 +707,6 @@ class TokenizerManager:
                             "text": recv_obj.output_strs[i],
                             "meta_info": meta_info,
                         }
-                        if self.server_args.return_token_ids:
-                            out_dict.update(
-                                {
-                                    "input_ids": recv_obj.origin_input_ids[i],
-                                    "output_ids": recv_obj.output_ids[i],
-                                }
-                            )
                     elif isinstance(recv_obj, BatchTokenIDOut):
                         out_dict = {
                             "token_ids": recv_obj.output_ids[i],
@@ -686,45 +723,9 @@ class TokenizerManager:
                     state.event.set()
 
                     if self.enable_metrics:
-
-
-
-                            else 0
-                        )
-
-                        if state.first_token_time is None:
-                            state.first_token_time = time.time()
-                            self.metrics_collector.observe_time_to_first_token(
-                                state.first_token_time - state.created_time
-                            )
-                        else:
-                            if completion_tokens >= 2:
-                                # Compute time_per_output_token for the streaming case
-                                self.metrics_collector.observe_time_per_output_token(
-                                    (time.time() - state.first_token_time)
-                                    / (completion_tokens - 1)
-                                )
-
-                        if state.finished:
-                            self.metrics_collector.inc_prompt_tokens(
-                                recv_obj.prompt_tokens[i]
-                            )
-                            self.metrics_collector.inc_generation_tokens(
-                                completion_tokens
-                            )
-                            self.metrics_collector.observe_e2e_request_latency(
-                                time.time() - state.created_time
-                            )
-                            # Compute time_per_output_token for the non-streaming case
-                            if (
-                                hasattr(state.obj, "stream")
-                                and not state.obj.stream
-                                and completion_tokens >= 1
-                            ):
-                                self.metrics_collector.observe_time_per_output_token(
-                                    (time.time() - state.created_time)
-                                    / completion_tokens
-                                )
+                        self.collect_metrics(state, recv_obj, i)
+                    if self.dump_requests_folder and state.finished:
+                        self.dump_requests(state, out_dict)
                 elif isinstance(recv_obj, OpenSessionReqOutput):
                     self.session_futures[recv_obj.session_id].set_result(
                         recv_obj.session_id if recv_obj.success else None
@@ -754,6 +755,10 @@ class TokenizerManager:
                 self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
             elif isinstance(recv_obj, GetWeightsByNameReqOutput):
                 self.get_weights_by_name_communicator.handle_recv(recv_obj)
+            elif isinstance(recv_obj, ReleaseMemoryOccupationReqOutput):
+                self.release_memory_occupation_communicator.handle_recv(recv_obj)
+            elif isinstance(recv_obj, ResumeMemoryOccupationReqOutput):
+                self.resume_memory_occupation_communicator.handle_recv(recv_obj)
             else:
                 raise ValueError(f"Invalid object: {recv_obj=}")
 
@@ -827,6 +832,61 @@ class TokenizerManager:
                 ret.append(None)
         return ret
 
+    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
+        completion_tokens = (
+            recv_obj.completion_tokens[i]
+            if getattr(recv_obj, "completion_tokens", None)
+            else 0
+        )
+
+        if state.first_token_time is None:
+            state.first_token_time = time.time()
+            self.metrics_collector.observe_time_to_first_token(
+                state.first_token_time - state.created_time
+            )
+        else:
+            if completion_tokens >= 2:
+                # Compute time_per_output_token for the streaming case
+                self.metrics_collector.observe_time_per_output_token(
+                    (time.time() - state.first_token_time) / (completion_tokens - 1)
+                )
+
+        if state.finished:
+            self.metrics_collector.observe_one_finished_request(
+                recv_obj.prompt_tokens[i], completion_tokens
+            )
+            self.metrics_collector.observe_e2e_request_latency(
+                time.time() - state.created_time
+            )
+            # Compute time_per_output_token for the non-streaming case
+            if (
+                hasattr(state.obj, "stream")
+                and not state.obj.stream
+                and completion_tokens >= 1
+            ):
+                self.metrics_collector.observe_time_per_output_token(
+                    (time.time() - state.created_time) / completion_tokens
+                )
+
+    def dump_requests(self, state: ReqState, out_dict: dict):
+        self.dump_request_list.append(
+            (state.obj, out_dict, state.created_time, time.time())
+        )
+
+        if len(self.dump_request_list) >= self.dump_requests_threshold:
+            to_dump = self.dump_request_list
+            self.dump_request_list = []
+
+            def background_task():
+                os.makedirs(self.dump_requests_folder, exist_ok=True)
+                current_time = datetime.now()
+                filename = current_time.strftime("%Y-%m-%d_%H-%M-%S") + ".pkl"
+                with open(os.path.join(self.dump_requests_folder, filename), "wb") as f:
+                    pickle.dump(to_dump, f)
+
+            # Schedule the task to run in the background without awaiting it
+            asyncio.create_task(asyncio.to_thread(background_task))
+
 
 class SignalHandler:
     def __init__(self, tokenizer_manager):
```
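The new `dump_requests` path batches finished requests as `(request_object, out_dict, created_time, finished_time)` tuples and, once `dump_requests_threshold` records accumulate, pickles the batch to a timestamped `.pkl` file under `dump_requests_folder` (both settable at runtime through `ConfigureLoggingReq`). A small offline reader, sketched here with a hypothetical folder path, can replay those dumps:

```python
import glob
import os
import pickle

# Hypothetical path: whatever was set via ConfigureLoggingReq.dump_requests_folder.
DUMP_FOLDER = "/tmp/sglang_request_dumps"

for path in sorted(glob.glob(os.path.join(DUMP_FOLDER, "*.pkl"))):
    with open(path, "rb") as f:
        # Each file holds a list of (request_obj, out_dict, created_time, finished_time).
        records = pickle.load(f)
    for req_obj, out, created, finished in records:
        latency = finished - created
        text = out.get("text", "") if isinstance(out, dict) else ""
        print(f"{os.path.basename(path)}: latency={latency:.3f}s text={text[:40]!r}")
```

Since the request objects are pickled SGLang dataclasses, the reader needs `sglang` importable in its environment to unpickle them.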
sglang/srt/mem_cache/memory_pool.py
CHANGED
```diff
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+
 """
 Memory pool.
 
@@ -27,6 +29,7 @@ from enum import IntEnum
 from functools import wraps
 from typing import List, Tuple, Union
 
+import numpy as np
 import psutil
 import torch
 
@@ -35,17 +38,31 @@ from sglang.srt.utils import debug_timing, get_compiler_backend
 
 logger = logging.getLogger(__name__)
 
+GB = 1024 * 1024 * 1024
+
 
 class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""
 
-    def __init__(
+    def __init__(
+        self,
+        size: int,
+        max_context_len: int,
+        device: str,
+        use_records: bool,
+        enable_memory_saver: bool,
+    ):
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
         self.size = size
         self.max_context_len = max_context_len
         self.device = device
-
-
-
+        with memory_saver_adapter.region():
+            self.req_to_token = torch.zeros(
+                (size, max_context_len), dtype=torch.int32, device=device
+            )
         self.free_slots = list(range(size))
         self.write_records = []
         self.use_records = use_records
@@ -109,8 +126,8 @@ class BaseTokenToKVPool:
     ):
         self.size = size
         self.dtype = dtype
-        if dtype
-        # NOTE: Store as torch.uint8 because Tensor
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
             self.store_dtype = torch.uint8
         else:
             self.store_dtype = dtype
@@ -186,37 +203,60 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         head_dim: int,
         layer_num: int,
         device: str,
+        enable_memory_saver: bool,
     ):
         super().__init__(size, dtype, device)
+
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
         self.head_num = head_num
         self.head_dim = head_dim
         self.layer_num = layer_num
         self._create_buffers()
 
+        k_size, v_size = self.get_kv_size_bytes()
+        logger.info(
+            f"KV Cache is allocated. K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB."
+        )
+
     def _create_buffers(self):
-
-
-
-
-        (
-
-
-
-
-
-
-
-        (
-
-
-
-
-
+        with self.memory_saver_adapter.region():
+            # [size, head_num, head_dim] for each layer
+            # The padded slot 0 is used for writing dummy outputs from padded tokens.
+            self.k_buffer = [
+                torch.empty(
+                    (self.size + 1, self.head_num, self.head_dim),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(self.layer_num)
+            ]
+            self.v_buffer = [
+                torch.empty(
+                    (self.size + 1, self.head_num, self.head_dim),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(self.layer_num)
+            ]
 
     def _clear_buffers(self):
         del self.k_buffer
         del self.v_buffer
 
+    def get_kv_size_bytes(self):
+        assert hasattr(self, "k_buffer")
+        assert hasattr(self, "v_buffer")
+        k_size_bytes = 0
+        for k_cache in self.k_buffer:
+            k_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
+        v_size_bytes = 0
+        for v_cache in self.v_buffer:
+            v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
+        return k_size_bytes, v_size_bytes
+
     # Todo: different memory layout
     def get_flat_data(self, indices):
         # prepare a large chunk of contiguous data for efficient transfer
@@ -256,11 +296,13 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
     ):
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
-            cache_k = cache_k.to(self.dtype)
-            cache_v = cache_v.to(self.dtype)
+            cache_k = (cache_k / k_scale).to(self.dtype)
+            cache_v = (cache_v / v_scale).to(self.dtype)
         if self.store_dtype != self.dtype:
             self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
             self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
@@ -286,19 +328,26 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         qk_rope_head_dim: int,
         layer_num: int,
         device: str,
+        enable_memory_saver: bool,
     ):
         super().__init__(size, dtype, device)
 
         self.kv_lora_rank = kv_lora_rank
-
-
-
-
-
-
-
-
-
+
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
+        with memory_saver_adapter.region():
+            # The padded slot 0 is used for writing dummy outputs from padded tokens.
+            self.kv_buffer = [
+                torch.empty(
+                    (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
+                    dtype=self.store_dtype,
+                    device=device,
+                )
+                for _ in range(layer_num)
+            ]
 
     def get_key_buffer(self, layer_id: int):
         if self.store_dtype != self.dtype:
@@ -339,26 +388,32 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
         layer_num: int,
         device: str,
         heavy_channel_num: int,
+        enable_memory_saver: bool,
     ):
         super().__init__(size, dtype, device)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        for
-
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
+        with memory_saver_adapter.region():
+            # [size, head_num, head_dim] for each layer
+            self.k_buffer = [
+                torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+                for _ in range(layer_num)
+            ]
+            self.v_buffer = [
+                torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+                for _ in range(layer_num)
+            ]
+
+            # [size, head_num, heavy_channel_num] for each layer
+            self.label_buffer = [
+                torch.empty(
+                    (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
+                )
+                for _ in range(layer_num)
+            ]
 
     def get_key_buffer(self, layer_id: int):
        return self.k_buffer[layer_id]
```
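The new `get_kv_size_bytes` makes the KV-cache footprint explicit: K and V are each `layer_num` buffers of shape `(size + 1, head_num, head_dim)` stored in `store_dtype`, and the logged figure divides by `GB = 1024**3`. A back-of-the-envelope sketch with made-up numbers (not taken from any real model config) reproduces the same arithmetic:

```python
import numpy as np
import torch

# Illustrative sizing only; these values are assumptions, not a real config.
size = 200_000          # token slots in the pool
head_num = 8            # KV heads per layer (after tensor parallelism)
head_dim = 128
layer_num = 32
store_dtype = torch.bfloat16

GB = 1024 * 1024 * 1024
per_layer = np.prod((size + 1, head_num, head_dim)) * store_dtype.itemsize
k_size = v_size = per_layer * layer_num
print(f"K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB")
# -> K size: 12.21 GB, V size: 12.21 GB
```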
sglang/srt/metrics/collector.py
CHANGED
```diff
@@ -109,6 +109,12 @@ class TokenizerMetricsCollector:
             labelnames=labels.keys(),
         )
 
+        self.num_requests_total = Counter(
+            name="sglang:num_requests_total",
+            documentation="Number of requests processed.",
+            labelnames=labels.keys(),
+        )
+
         self.histogram_time_to_first_token = Histogram(
             name="sglang:time_to_first_token_seconds",
             documentation="Histogram of time to first token in seconds.",
@@ -185,11 +191,10 @@ class TokenizerMetricsCollector:
         # Convenience function for logging to counter.
         counter.labels(**self.labels).inc(data)
 
-    def
-        self.
-
-
-        self._log_counter(self.generation_tokens_total, value)
+    def observe_one_finished_request(self, prompt_tokens: int, generation_tokens: int):
+        self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
+        self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
+        self.num_requests_total.labels(**self.labels).inc(1)
 
     def observe_time_to_first_token(self, value: Union[float, int]):
         self._log_histogram(self.histogram_time_to_first_token, value)
```