sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post6__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (39)
  1. sglang/srt/configs/model_config.py +15 -6
  2. sglang/srt/layers/attention/flashinfer_backend.py +17 -3
  3. sglang/srt/layers/linear.py +36 -98
  4. sglang/srt/layers/moe/fused_moe_triton/layer.py +37 -9
  5. sglang/srt/layers/moe/topk.py +4 -2
  6. sglang/srt/layers/parameter.py +24 -16
  7. sglang/srt/layers/quantization/__init__.py +2 -0
  8. sglang/srt/layers/quantization/fp8.py +106 -52
  9. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  10. sglang/srt/layers/quantization/int8_kernel.py +54 -0
  11. sglang/srt/layers/quantization/modelopt_quant.py +1 -1
  12. sglang/srt/layers/quantization/w8a8_int8.py +117 -0
  13. sglang/srt/layers/radix_attention.py +2 -0
  14. sglang/srt/layers/vocab_parallel_embedding.py +15 -2
  15. sglang/srt/managers/configure_logging.py +43 -0
  16. sglang/srt/managers/detokenizer_manager.py +0 -2
  17. sglang/srt/managers/io_struct.py +29 -13
  18. sglang/srt/managers/scheduler.py +48 -9
  19. sglang/srt/managers/tokenizer_manager.py +109 -49
  20. sglang/srt/mem_cache/memory_pool.py +107 -52
  21. sglang/srt/metrics/collector.py +10 -5
  22. sglang/srt/model_executor/model_runner.py +43 -6
  23. sglang/srt/models/llama.py +37 -2
  24. sglang/srt/models/qwen2.py +11 -0
  25. sglang/srt/models/qwen2_eagle.py +131 -0
  26. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
  27. sglang/srt/sampling/sampling_batch_info.py +14 -5
  28. sglang/srt/sampling/sampling_params.py +1 -1
  29. sglang/srt/server.py +114 -61
  30. sglang/srt/server_args.py +27 -18
  31. sglang/srt/speculative/eagle_worker.py +1 -0
  32. sglang/srt/torch_memory_saver_adapter.py +59 -0
  33. sglang/srt/utils.py +29 -0
  34. sglang/version.py +1 -1
  35. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/METADATA +12 -10
  36. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/RECORD +39 -34
  37. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/LICENSE +0 -0
  38. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/WHEEL +0 -0
  39. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/top_level.txt +0 -0

sglang/srt/managers/scheduler.py
@@ -13,6 +13,7 @@
  # ==============================================================================
  """A scheduler that manages a tensor parallel GPU worker."""

+ import faulthandler
  import logging
  import os
  import signal
@@ -46,6 +47,10 @@ from sglang.srt.managers.io_struct import (
  OpenSessionReqInput,
  OpenSessionReqOutput,
  ProfileReq,
+ ReleaseMemoryOccupationReqInput,
+ ReleaseMemoryOccupationReqOutput,
+ ResumeMemoryOccupationReqInput,
+ ResumeMemoryOccupationReqOutput,
  TokenizedEmbeddingReqInput,
  TokenizedGenerateReqInput,
  UpdateWeightFromDiskReqInput,
@@ -77,6 +82,7 @@ from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerSta
  from sglang.srt.model_executor.forward_batch_info import ForwardMode
  from sglang.srt.server_args import PortArgs, ServerArgs
  from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+ from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
  from sglang.srt.utils import (
  broadcast_pyobj,
  configure_logger,
@@ -356,6 +362,10 @@ class Scheduler:
  t.start()
  self.parent_process = psutil.Process().parent()

+ self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+ enable=server_args.enable_memory_saver
+ )
+
  # Init profiler
  if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
  self.profiler = None
@@ -399,6 +409,8 @@
  self.watchdog_last_time = time.time()
  time.sleep(self.watchdog_timeout / 2)

+ # Wait sometimes so that the parent process can print the error.
+ time.sleep(5)
  self.parent_process.send_signal(signal.SIGQUIT)

  @torch.no_grad()
@@ -516,6 +528,12 @@
  elif isinstance(recv_req, GetWeightsByNameReqInput):
  parameter = self.get_weights_by_name(recv_req)
  self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
+ elif isinstance(recv_req, ReleaseMemoryOccupationReqInput):
+ self.release_memory_occupation()
+ self.send_to_tokenizer.send_pyobj(ReleaseMemoryOccupationReqOutput())
+ elif isinstance(recv_req, ResumeMemoryOccupationReqInput):
+ self.resume_memory_occupation()
+ self.send_to_tokenizer.send_pyobj(ResumeMemoryOccupationReqOutput())
  elif isinstance(recv_req, ProfileReq):
  if recv_req == ProfileReq.START_PROFILE:
  self.start_profile()
@@ -1253,7 +1271,6 @@
  decode_ids_list = []
  read_offsets = []
  output_ids = []
- origin_input_ids = []

  skip_special_tokens = []
  spaces_between_special_tokens = []
@@ -1305,14 +1322,8 @@
  decode_ids, read_offset = req.init_incremental_detokenize()
  decode_ids_list.append(decode_ids)
  read_offsets.append(read_offset)
- if self.skip_tokenizer_init or self.server_args.return_token_ids:
+ if self.skip_tokenizer_init:
  output_ids.append(req.output_ids)
- else:
- output_ids = None
- if self.server_args.return_token_ids:
- origin_input_ids.append(req.origin_input_ids)
- else:
- origin_input_ids = None
  skip_special_tokens.append(req.sampling_params.skip_special_tokens)
  spaces_between_special_tokens.append(
  req.sampling_params.spaces_between_special_tokens
@@ -1344,7 +1355,6 @@
  decoded_texts,
  decode_ids_list,
  read_offsets,
- origin_input_ids,
  output_ids,
  skip_special_tokens,
  spaces_between_special_tokens,
@@ -1543,6 +1553,20 @@
  parameter = self.tp_worker.get_weights_by_name(recv_req)
  return parameter

+ def release_memory_occupation(self):
+ self.stashed_model_static_state = _export_static_state(
+ self.tp_worker.worker.model_runner.model
+ )
+ self.memory_saver_adapter.pause()
+ self.flush_cache()
+
+ def resume_memory_occupation(self):
+ self.memory_saver_adapter.resume()
+ _import_static_state(
+ self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
+ )
+ del self.stashed_model_static_state
+
  def start_profile(self) -> None:
  if self.profiler is None:
  raise RuntimeError("Profiler is not enabled.")
@@ -1581,6 +1605,20 @@
  del self.sessions[session_id]


+ def _export_static_state(model):
+ return dict(
+ buffers=[
+ (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
+ ]
+ )
+
+
+ def _import_static_state(model, static_params):
+ self_named_buffers = dict(model.named_buffers())
+ for name, tensor in static_params["buffers"]:
+ self_named_buffers[name][...] = tensor
+
+
  def run_scheduler_process(
  server_args: ServerArgs,
  port_args: PortArgs,
@@ -1590,6 +1628,7 @@ def run_scheduler_process(
  pipe_writer,
  ):
  setproctitle.setproctitle("sglang::scheduler")
+ faulthandler.enable()

  # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
  if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
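
Note on the new release/resume path above: before memory is released, the scheduler stashes the model's non-parameter state (its named buffers) with _export_static_state and writes it back with _import_static_state after resume. A minimal, self-contained sketch of that stash/restore pattern, assuming only torch (the toy module and buffer name below are illustrative, not sglang code):

import torch
from torch import nn

# Toy module with one buffer standing in for model state such as a rotary cache.
model = nn.Module()
model.register_buffer("rope_cache", torch.arange(4, dtype=torch.float32))

# Export: clone every named buffer so it survives the release of GPU memory.
stashed = dict(
    buffers=[(name, buf.detach().clone()) for name, buf in model.named_buffers()]
)

model.rope_cache.zero_()  # stand-in for "memory released and later re-allocated"

# Import: copy the stashed values back into the live buffers in place.
named_buffers = dict(model.named_buffers())
for name, tensor in stashed["buffers"]:
    named_buffers[name][...] = tensor

assert torch.equal(model.rope_cache, torch.arange(4, dtype=torch.float32))
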

sglang/srt/managers/tokenizer_manager.py
@@ -18,10 +18,12 @@ import copy
  import dataclasses
  import logging
  import os
+ import pickle
  import signal
  import sys
  import time
  import uuid
+ from datetime import datetime
  from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union

  import fastapi
@@ -43,6 +45,7 @@ from sglang.srt.managers.io_struct import (
  BatchStrOut,
  BatchTokenIDOut,
  CloseSessionReqInput,
+ ConfigureLoggingReq,
  EmbeddingReqInput,
  FlushCacheReq,
  GenerateReqInput,
@@ -53,6 +56,10 @@ from sglang.srt.managers.io_struct import (
  OpenSessionReqInput,
  OpenSessionReqOutput,
  ProfileReq,
+ ReleaseMemoryOccupationReqInput,
+ ReleaseMemoryOccupationReqOutput,
+ ResumeMemoryOccupationReqInput,
+ ResumeMemoryOccupationReqOutput,
  SessionParams,
  TokenizedEmbeddingReqInput,
  TokenizedGenerateReqInput,
@@ -105,6 +112,7 @@ class TokenizerManager:
  # Parse args
  self.server_args = server_args
  self.enable_metrics = server_args.enable_metrics
+ self.log_requests = server_args.log_requests

  # Init inter-process communication
  context = zmq.asyncio.Context(2)
@@ -163,6 +171,9 @@
  # Store states
  self.to_create_loop = True
  self.rid_to_state: Dict[str, ReqState] = {}
+ self.dump_requests_folder = "" # By default do not dump
+ self.dump_requests_threshold = 1000
+ self.dump_request_list: List[Tuple] = []

  # The event to notify the weight sync is finished.
  self.model_update_lock = RWLock()
@@ -188,6 +199,12 @@
  self.get_weights_by_name_communicator = _Communicator(
  self.send_to_scheduler, server_args.dp_size
  )
+ self.release_memory_occupation_communicator = _Communicator(
+ self.send_to_scheduler, server_args.dp_size
+ )
+ self.resume_memory_occupation_communicator = _Communicator(
+ self.send_to_scheduler, server_args.dp_size
+ )

  # Metrics
  if self.enable_metrics:
@@ -215,7 +232,7 @@

  obj.normalize_batch_and_arguments()

- if self.server_args.log_requests:
+ if self.log_requests:
  logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}")

  async with self.model_update_lock.reader_lock:
@@ -336,7 +353,7 @@

  state.out_list = []
  if state.finished:
- if self.server_args.log_requests:
+ if self.log_requests:
  msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
  logger.info(msg)
  del self.rid_to_state[obj.rid]
@@ -548,6 +565,22 @@
  else:
  return all_parameters

+ async def release_memory_occupation(
+ self,
+ obj: ReleaseMemoryOccupationReqInput,
+ request: Optional[fastapi.Request] = None,
+ ):
+ self.auto_create_handle_loop()
+ await self.release_memory_occupation_communicator(obj)
+
+ async def resume_memory_occupation(
+ self,
+ obj: ResumeMemoryOccupationReqInput,
+ request: Optional[fastapi.Request] = None,
+ ):
+ self.auto_create_handle_loop()
+ await self.resume_memory_occupation_communicator(obj)
+
  async def open_session(
  self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
  ):
@@ -571,6 +604,15 @@
  assert not self.to_create_loop, "close session should not be the first request"
  await self.send_to_scheduler.send_pyobj(obj)

+ def configure_logging(self, obj: ConfigureLoggingReq):
+ if obj.log_requests is not None:
+ self.log_requests = obj.log_requests
+ if obj.dump_requests_folder is not None:
+ self.dump_requests_folder = obj.dump_requests_folder
+ if obj.dump_requests_threshold is not None:
+ self.dump_requests_threshold = obj.dump_requests_threshold
+ logging.info(f"Config logging: {obj=}")
+
  def create_abort_task(self, obj: GenerateReqInput):
  # Abort the request if the client is disconnected.
  async def abort_request():
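
The configure_logging handler above applies only the fields of ConfigureLoggingReq that are not None, so one request can change a single setting, for example turning on request dumping without touching log_requests. An illustrative sketch of the fields it reads; the stand-in dataclass below is not the real definition from sglang/srt/managers/io_struct.py:

from dataclasses import dataclass
from typing import Optional

@dataclass
class ConfigureLoggingReqSketch:  # illustrative stand-in
    log_requests: Optional[bool] = None
    dump_requests_folder: Optional[str] = None
    dump_requests_threshold: Optional[int] = None

# Enable request dumping while leaving the log_requests flag unchanged.
req = ConfigureLoggingReqSketch(dump_requests_folder="/tmp/sglang_dumps")
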
@@ -601,7 +643,7 @@
  while not self.gracefully_exit:
  await asyncio.sleep(5)

- # drain requests
+ # Drain requests
  while True:
  remain_num_req = len(self.rid_to_state)
  logger.info(
@@ -627,6 +669,8 @@
  UpdateWeightsFromDistributedReqOutput,
  GetWeightsByNameReqOutput,
  InitWeightsUpdateGroupReqOutput,
+ ReleaseMemoryOccupationReqOutput,
+ ResumeMemoryOccupationReqOutput,
  ] = await self.recv_from_detokenizer.recv_pyobj()

  if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
@@ -663,13 +707,6 @@
  "text": recv_obj.output_strs[i],
  "meta_info": meta_info,
  }
- if self.server_args.return_token_ids:
- out_dict.update(
- {
- "input_ids": recv_obj.origin_input_ids[i],
- "output_ids": recv_obj.output_ids[i],
- }
- )
  elif isinstance(recv_obj, BatchTokenIDOut):
  out_dict = {
  "token_ids": recv_obj.output_ids[i],
@@ -686,45 +723,9 @@
  state.event.set()

  if self.enable_metrics:
- completion_tokens = (
- recv_obj.completion_tokens[i]
- if getattr(recv_obj, "completion_tokens", None)
- else 0
- )
-
- if state.first_token_time is None:
- state.first_token_time = time.time()
- self.metrics_collector.observe_time_to_first_token(
- state.first_token_time - state.created_time
- )
- else:
- if completion_tokens >= 2:
- # Compute time_per_output_token for the streaming case
- self.metrics_collector.observe_time_per_output_token(
- (time.time() - state.first_token_time)
- / (completion_tokens - 1)
- )
-
- if state.finished:
- self.metrics_collector.inc_prompt_tokens(
- recv_obj.prompt_tokens[i]
- )
- self.metrics_collector.inc_generation_tokens(
- completion_tokens
- )
- self.metrics_collector.observe_e2e_request_latency(
- time.time() - state.created_time
- )
- # Compute time_per_output_token for the non-streaming case
- if (
- hasattr(state.obj, "stream")
- and not state.obj.stream
- and completion_tokens >= 1
- ):
- self.metrics_collector.observe_time_per_output_token(
- (time.time() - state.created_time)
- / completion_tokens
- )
+ self.collect_metrics(state, recv_obj, i)
+ if self.dump_requests_folder and state.finished:
+ self.dump_requests(state, out_dict)
  elif isinstance(recv_obj, OpenSessionReqOutput):
  self.session_futures[recv_obj.session_id].set_result(
  recv_obj.session_id if recv_obj.success else None
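
The metric math that used to live inline here is unchanged by the refactor; it now sits in collect_metrics (shown further below). Time-to-first-token is measured from request creation, the streaming time-per-output-token from the first token onward, and end-to-end latency from creation to finish. A worked example with made-up timestamps:

# Illustrative numbers only.
created_time, first_token_time, now, completion_tokens = 0.0, 0.8, 4.8, 9

ttft = first_token_time - created_time                                # 0.8 s
tpot_streaming = (now - first_token_time) / (completion_tokens - 1)  # 0.5 s per output token
e2e_latency = now - created_time                                      # 4.8 s
print(ttft, tpot_streaming, e2e_latency)
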
@@ -754,6 +755,10 @@
  self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
  elif isinstance(recv_obj, GetWeightsByNameReqOutput):
  self.get_weights_by_name_communicator.handle_recv(recv_obj)
+ elif isinstance(recv_obj, ReleaseMemoryOccupationReqOutput):
+ self.release_memory_occupation_communicator.handle_recv(recv_obj)
+ elif isinstance(recv_obj, ResumeMemoryOccupationReqOutput):
+ self.resume_memory_occupation_communicator.handle_recv(recv_obj)
  else:
  raise ValueError(f"Invalid object: {recv_obj=}")

@@ -827,6 +832,61 @@
  ret.append(None)
  return ret

+ def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
+ completion_tokens = (
+ recv_obj.completion_tokens[i]
+ if getattr(recv_obj, "completion_tokens", None)
+ else 0
+ )
+
+ if state.first_token_time is None:
+ state.first_token_time = time.time()
+ self.metrics_collector.observe_time_to_first_token(
+ state.first_token_time - state.created_time
+ )
+ else:
+ if completion_tokens >= 2:
+ # Compute time_per_output_token for the streaming case
+ self.metrics_collector.observe_time_per_output_token(
+ (time.time() - state.first_token_time) / (completion_tokens - 1)
+ )
+
+ if state.finished:
+ self.metrics_collector.observe_one_finished_request(
+ recv_obj.prompt_tokens[i], completion_tokens
+ )
+ self.metrics_collector.observe_e2e_request_latency(
+ time.time() - state.created_time
+ )
+ # Compute time_per_output_token for the non-streaming case
+ if (
+ hasattr(state.obj, "stream")
+ and not state.obj.stream
+ and completion_tokens >= 1
+ ):
+ self.metrics_collector.observe_time_per_output_token(
+ (time.time() - state.created_time) / completion_tokens
+ )
+
+ def dump_requests(self, state: ReqState, out_dict: dict):
+ self.dump_request_list.append(
+ (state.obj, out_dict, state.created_time, time.time())
+ )
+
+ if len(self.dump_request_list) >= self.dump_requests_threshold:
+ to_dump = self.dump_request_list
+ self.dump_request_list = []
+
+ def background_task():
+ os.makedirs(self.dump_requests_folder, exist_ok=True)
+ current_time = datetime.now()
+ filename = current_time.strftime("%Y-%m-%d_%H-%M-%S") + ".pkl"
+ with open(os.path.join(self.dump_requests_folder, filename), "wb") as f:
+ pickle.dump(to_dump, f)
+
+ # Schedule the task to run in the background without awaiting it
+ asyncio.create_task(asyncio.to_thread(background_task))
+

  class SignalHandler:
  def __init__(self, tokenizer_manager):
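
dump_requests batches finished requests and pickles them once dump_requests_threshold records have accumulated; each .pkl file holds a list of (request_obj, out_dict, created_time, finish_time) tuples. A small sketch for reading the dumps back offline (the folder path is hypothetical and stands for whatever was set via dump_requests_folder):

import glob
import os
import pickle

dump_folder = "/tmp/sglang_dumps"  # hypothetical; configured via dump_requests_folder
for path in sorted(glob.glob(os.path.join(dump_folder, "*.pkl"))):
    with open(path, "rb") as f:
        records = pickle.load(f)
    for req_obj, out_dict, created_time, finish_time in records:
        # For text requests out_dict carries "text" and "meta_info".
        print(f"{finish_time - created_time:.3f}s  {out_dict.get('text', '')[:60]!r}")
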

sglang/srt/mem_cache/memory_pool.py
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """

+ from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+
  """
  Memory pool.

@@ -27,6 +29,7 @@ from enum import IntEnum
  from functools import wraps
  from typing import List, Tuple, Union

+ import numpy as np
  import psutil
  import torch

@@ -35,17 +38,31 @@ from sglang.srt.utils import debug_timing, get_compiler_backend

  logger = logging.getLogger(__name__)

+ GB = 1024 * 1024 * 1024
+

  class ReqToTokenPool:
  """A memory pool that maps a request to its token locations."""

- def __init__(self, size: int, max_context_len: int, device: str, use_records: bool):
+ def __init__(
+ self,
+ size: int,
+ max_context_len: int,
+ device: str,
+ use_records: bool,
+ enable_memory_saver: bool,
+ ):
+ memory_saver_adapter = TorchMemorySaverAdapter.create(
+ enable=enable_memory_saver
+ )
+
  self.size = size
  self.max_context_len = max_context_len
  self.device = device
- self.req_to_token = torch.zeros(
- (size, max_context_len), dtype=torch.int32, device=device
- )
+ with memory_saver_adapter.region():
+ self.req_to_token = torch.zeros(
+ (size, max_context_len), dtype=torch.int32, device=device
+ )
  self.free_slots = list(range(size))
  self.write_records = []
  self.use_records = use_records
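
The memory pools now route their large allocations through TorchMemorySaverAdapter (new file sglang/srt/torch_memory_saver_adapter.py) so that, when enable_memory_saver is set, the backing GPU memory can later be paused and resumed by the scheduler. A minimal no-op sketch of the call pattern the code above relies on (create, region, pause, resume); this is an illustration, not the real implementation:

from contextlib import contextmanager

class NoopMemorySaverAdapter:  # illustrative stand-in
    @staticmethod
    def create(enable: bool):
        # The real adapter presumably returns a functional variant when enable=True.
        return NoopMemorySaverAdapter()

    @contextmanager
    def region(self):
        yield  # tensors allocated inside can later be paused/resumed

    def pause(self):
        pass  # release the GPU memory backing tensors from region()

    def resume(self):
        pass  # re-allocate that memory so the tensors are usable again

adapter = NoopMemorySaverAdapter.create(enable=False)
with adapter.region():
    pass  # e.g. allocate req_to_token or KV-cache buffers here
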
@@ -109,8 +126,8 @@ class BaseTokenToKVPool:
  ):
  self.size = size
  self.dtype = dtype
- if dtype == torch.float8_e5m2:
- # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
+ if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+ # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
  self.store_dtype = torch.uint8
  else:
  self.store_dtype = dtype
@@ -186,37 +203,60 @@ class MHATokenToKVPool(BaseTokenToKVPool):
  head_dim: int,
  layer_num: int,
  device: str,
+ enable_memory_saver: bool,
  ):
  super().__init__(size, dtype, device)
+
+ self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+ enable=enable_memory_saver
+ )
+
  self.head_num = head_num
  self.head_dim = head_dim
  self.layer_num = layer_num
  self._create_buffers()

+ k_size, v_size = self.get_kv_size_bytes()
+ logger.info(
+ f"KV Cache is allocated. K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB."
+ )
+
  def _create_buffers(self):
- # [size, head_num, head_dim] for each layer
- # The padded slot 0 is used for writing dummy outputs from padded tokens.
- self.k_buffer = [
- torch.empty(
- (self.size + 1, self.head_num, self.head_dim),
- dtype=self.store_dtype,
- device=self.device,
- )
- for _ in range(self.layer_num)
- ]
- self.v_buffer = [
- torch.empty(
- (self.size + 1, self.head_num, self.head_dim),
- dtype=self.store_dtype,
- device=self.device,
- )
- for _ in range(self.layer_num)
- ]
+ with self.memory_saver_adapter.region():
+ # [size, head_num, head_dim] for each layer
+ # The padded slot 0 is used for writing dummy outputs from padded tokens.
+ self.k_buffer = [
+ torch.empty(
+ (self.size + 1, self.head_num, self.head_dim),
+ dtype=self.store_dtype,
+ device=self.device,
+ )
+ for _ in range(self.layer_num)
+ ]
+ self.v_buffer = [
+ torch.empty(
+ (self.size + 1, self.head_num, self.head_dim),
+ dtype=self.store_dtype,
+ device=self.device,
+ )
+ for _ in range(self.layer_num)
+ ]

  def _clear_buffers(self):
  del self.k_buffer
  del self.v_buffer

+ def get_kv_size_bytes(self):
+ assert hasattr(self, "k_buffer")
+ assert hasattr(self, "v_buffer")
+ k_size_bytes = 0
+ for k_cache in self.k_buffer:
+ k_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
+ v_size_bytes = 0
+ for v_cache in self.v_buffer:
+ v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
+ return k_size_bytes, v_size_bytes
+
  # Todo: different memory layout
  def get_flat_data(self, indices):
  # prepare a large chunk of contiguous data for efficient transfer
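
get_kv_size_bytes sums shape-product times element size over the per-layer K and V buffers, and the new startup log divides by GB = 1024**3. A worked example with hypothetical dimensions:

layer_num, head_num, head_dim = 32, 8, 128  # hypothetical model shape
size, itemsize = 100_000, 2                 # pool slots, fp16 bytes per element
GB = 1024 * 1024 * 1024

k_size_bytes = layer_num * (size + 1) * head_num * head_dim * itemsize
v_size_bytes = k_size_bytes  # V buffers have the same shape as K buffers
print(f"K size: {k_size_bytes / GB:.2f} GB, V size: {v_size_bytes / GB:.2f} GB")
# K size: 6.10 GB, V size: 6.10 GB
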
@@ -256,11 +296,13 @@
  loc: torch.Tensor,
  cache_k: torch.Tensor,
  cache_v: torch.Tensor,
+ k_scale: float = 1.0,
+ v_scale: float = 1.0,
  ):
  layer_id = layer.layer_id
  if cache_k.dtype != self.dtype:
- cache_k = cache_k.to(self.dtype)
- cache_v = cache_v.to(self.dtype)
+ cache_k = (cache_k / k_scale).to(self.dtype)
+ cache_v = (cache_v / v_scale).to(self.dtype)
  if self.store_dtype != self.dtype:
  self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
  self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
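
set_kv_buffer now divides the incoming K/V by the provided scales before casting to the pool dtype, which matters when the pool stores FP8; a reader would multiply by the same scale to recover the original magnitude. A small sketch of that round trip (assumes a torch build with float8 support; the scale is made up):

import torch

cache_k = torch.randn(4, 8, 16, dtype=torch.bfloat16)
k_scale = 2.0  # illustrative scale

stored = (cache_k / k_scale).to(torch.float8_e4m3fn)  # what the pool keeps
restored = stored.to(torch.bfloat16) * k_scale        # what a consumer reads back
print((restored - cache_k).abs().max())               # fp8 round-trip error
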
@@ -286,19 +328,26 @@ class MLATokenToKVPool(BaseTokenToKVPool):
  qk_rope_head_dim: int,
  layer_num: int,
  device: str,
+ enable_memory_saver: bool,
  ):
  super().__init__(size, dtype, device)

  self.kv_lora_rank = kv_lora_rank
- # The padded slot 0 is used for writing dummy outputs from padded tokens.
- self.kv_buffer = [
- torch.empty(
- (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
- dtype=self.store_dtype,
- device=device,
- )
- for _ in range(layer_num)
- ]
+
+ memory_saver_adapter = TorchMemorySaverAdapter.create(
+ enable=enable_memory_saver
+ )
+
+ with memory_saver_adapter.region():
+ # The padded slot 0 is used for writing dummy outputs from padded tokens.
+ self.kv_buffer = [
+ torch.empty(
+ (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
+ dtype=self.store_dtype,
+ device=device,
+ )
+ for _ in range(layer_num)
+ ]

  def get_key_buffer(self, layer_id: int):
  if self.store_dtype != self.dtype:
@@ -339,26 +388,32 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
  layer_num: int,
  device: str,
  heavy_channel_num: int,
+ enable_memory_saver: bool,
  ):
  super().__init__(size, dtype, device)

- # [size, head_num, head_dim] for each layer
- self.k_buffer = [
- torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
- for _ in range(layer_num)
- ]
- self.v_buffer = [
- torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
- for _ in range(layer_num)
- ]
-
- # [size, head_num, heavy_channel_num] for each layer
- self.label_buffer = [
- torch.empty(
- (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
- )
- for _ in range(layer_num)
- ]
+ memory_saver_adapter = TorchMemorySaverAdapter.create(
+ enable=enable_memory_saver
+ )
+
+ with memory_saver_adapter.region():
+ # [size, head_num, head_dim] for each layer
+ self.k_buffer = [
+ torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+ for _ in range(layer_num)
+ ]
+ self.v_buffer = [
+ torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+ for _ in range(layer_num)
+ ]
+
+ # [size, head_num, heavy_channel_num] for each layer
+ self.label_buffer = [
+ torch.empty(
+ (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
+ )
+ for _ in range(layer_num)
+ ]

  def get_key_buffer(self, layer_id: int):
  return self.k_buffer[layer_id]

sglang/srt/metrics/collector.py
@@ -109,6 +109,12 @@ class TokenizerMetricsCollector:
  labelnames=labels.keys(),
  )

+ self.num_requests_total = Counter(
+ name="sglang:num_requests_total",
+ documentation="Number of requests processed.",
+ labelnames=labels.keys(),
+ )
+
  self.histogram_time_to_first_token = Histogram(
  name="sglang:time_to_first_token_seconds",
  documentation="Histogram of time to first token in seconds.",
@@ -185,11 +191,10 @@
  # Convenience function for logging to counter.
  counter.labels(**self.labels).inc(data)

- def inc_prompt_tokens(self, value: int):
- self._log_counter(self.prompt_tokens_total, value)
-
- def inc_generation_tokens(self, value: int):
- self._log_counter(self.generation_tokens_total, value)
+ def observe_one_finished_request(self, prompt_tokens: int, generation_tokens: int):
+ self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
+ self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
+ self.num_requests_total.labels(**self.labels).inc(1)

  def observe_time_to_first_token(self, value: Union[float, int]):
  self._log_histogram(self.histogram_time_to_first_token, value)