sglang 0.4.5.post2__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. sglang/bench_serving.py +3 -2
  2. sglang/compile_deep_gemm.py +136 -0
  3. sglang/lang/backend/openai.py +5 -1
  4. sglang/lang/backend/runtime_endpoint.py +5 -1
  5. sglang/srt/configs/model_config.py +4 -1
  6. sglang/srt/constrained/xgrammar_backend.py +1 -0
  7. sglang/srt/disaggregation/decode.py +43 -0
  8. sglang/srt/disaggregation/mini_lb.py +69 -8
  9. sglang/srt/disaggregation/mooncake/conn.py +1 -1
  10. sglang/srt/disaggregation/nixl/__init__.py +1 -0
  11. sglang/srt/disaggregation/nixl/conn.py +622 -0
  12. sglang/srt/disaggregation/prefill.py +100 -16
  13. sglang/srt/disaggregation/utils.py +17 -0
  14. sglang/srt/entrypoints/engine.py +4 -0
  15. sglang/srt/entrypoints/http_server.py +3 -7
  16. sglang/srt/function_call_parser.py +60 -0
  17. sglang/srt/layers/activation.py +2 -2
  18. sglang/srt/layers/attention/flashattention_backend.py +781 -150
  19. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
  20. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
  21. sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
  22. sglang/srt/layers/dp_attention.py +1 -1
  23. sglang/srt/layers/layernorm.py +19 -4
  24. sglang/srt/layers/moe/ep_moe/layer.py +2 -0
  25. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  26. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  27. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
  28. sglang/srt/layers/quantization/deep_gemm.py +378 -0
  29. sglang/srt/layers/quantization/fp8_kernel.py +7 -38
  30. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  31. sglang/srt/layers/quantization/gptq.py +13 -7
  32. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  33. sglang/srt/layers/quantization/w8a8_int8.py +3 -3
  34. sglang/srt/layers/rotary_embedding.py +6 -6
  35. sglang/srt/layers/sampler.py +2 -2
  36. sglang/srt/managers/data_parallel_controller.py +7 -1
  37. sglang/srt/managers/io_struct.py +14 -3
  38. sglang/srt/managers/schedule_batch.py +13 -0
  39. sglang/srt/managers/scheduler.py +16 -6
  40. sglang/srt/managers/tokenizer_manager.py +115 -29
  41. sglang/srt/managers/tp_worker.py +1 -0
  42. sglang/srt/mem_cache/hiradix_cache.py +40 -32
  43. sglang/srt/mem_cache/memory_pool.py +31 -13
  44. sglang/srt/model_executor/cuda_graph_runner.py +13 -8
  45. sglang/srt/model_executor/model_runner.py +19 -4
  46. sglang/srt/models/deepseek_v2.py +9 -6
  47. sglang/srt/models/minicpm3.py +2 -2
  48. sglang/srt/models/minicpmo.py +17 -6
  49. sglang/srt/openai_api/adapter.py +71 -4
  50. sglang/srt/openai_api/protocol.py +6 -1
  51. sglang/srt/server_args.py +52 -40
  52. sglang/srt/speculative/build_eagle_tree.py +2 -2
  53. sglang/srt/speculative/eagle_utils.py +2 -2
  54. sglang/srt/speculative/eagle_worker.py +2 -7
  55. sglang/srt/utils.py +46 -5
  56. sglang/test/test_utils.py +3 -1
  57. sglang/version.py +1 -1
  58. {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +3 -3
  59. {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +62 -57
  60. {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +0 -0
  61. {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
  62. {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
@@ -96,8 +96,8 @@ class GenerateReqInput:
     return_hidden_states: bool = False
 
     # For disaggregated inference
-    bootstrap_host: Optional[str] = None
-    bootstrap_room: Optional[int] = None
+    bootstrap_host: Optional[Union[List[str], str]] = None
+    bootstrap_room: Optional[Union[List[int], int]] = None
 
     def normalize_batch_and_arguments(self):
         """
@@ -397,6 +397,12 @@ class GenerateReqInput:
                 else None
             ),
             return_hidden_states=self.return_hidden_states,
+            bootstrap_host=(
+                self.bootstrap_host[i] if self.bootstrap_host is not None else None
+            ),
+            bootstrap_room=(
+                self.bootstrap_room[i] if self.bootstrap_room is not None else None
+            ),
         )
 
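For context, the two hunks above widen `bootstrap_host` and `bootstrap_room` so a batched `GenerateReqInput` can carry one bootstrap target per request in disaggregated mode, and the per-item accessor hands each sub-request its own value. Below is a minimal standalone sketch of that indexing pattern; the `MiniReq` class is hypothetical and assumes the batched fields are lists with one entry per item, only the field names and the conditional indexing come from the diff.

    from dataclasses import dataclass
    from typing import List, Optional, Union


    @dataclass
    class MiniReq:
        # Batched fields: one value per item when the request is a batch.
        text: List[str]
        bootstrap_host: Optional[Union[List[str], str]] = None
        bootstrap_room: Optional[Union[List[int], int]] = None

        def __getitem__(self, i: int) -> "MiniReq":
            # Hand each sub-request its own bootstrap target, mirroring the diff above.
            return MiniReq(
                text=[self.text[i]],
                bootstrap_host=(
                    self.bootstrap_host[i] if self.bootstrap_host is not None else None
                ),
                bootstrap_room=(
                    self.bootstrap_room[i] if self.bootstrap_room is not None else None
                ),
            )


    batch = MiniReq(
        text=["a", "b"],
        bootstrap_host=["prefill-0", "prefill-1"],
        bootstrap_room=[1001, 1002],
    )
    print(batch[1].bootstrap_host, batch[1].bootstrap_room)  # prefill-1 1002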
@@ -665,10 +671,15 @@ class BatchEmbeddingOut:
 
 
 @dataclass
-class FlushCacheReq:
+class FlushCacheReqInput:
     pass
 
 
+@dataclass
+class FlushCacheReqOutput:
+    success: bool
+
+
 @dataclass
 class UpdateWeightFromDiskReqInput:
     # The model path with the new weights
@@ -539,6 +539,11 @@ class Req:
         # The first output_id transferred from prefill instance.
         self.transferred_output_id: Optional[int] = None
 
+        # For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
+        # This is because kv is not ready in `process_prefill_chunk`.
+        # We use `tmp_end_idx` to store the end index of the kv cache to send.
+        self.tmp_end_idx: int = -1
+
     @property
     def seqlen(self):
         return len(self.origin_input_ids) + len(self.output_ids)
@@ -571,6 +576,14 @@ class Req:
             self.prefix_indices, self.last_node = tree_cache.match_prefix(
                 rid=self.rid, key=self.adjust_max_prefix_ids()
             )
+        elif enable_hierarchical_cache:
+            # in case last_node is evicted during scheduling, we need to update the prefix_indices
+            while self.last_node.evicted:
+                self.prefix_indices = self.prefix_indices[
+                    : -len(self.last_node.host_value)
+                ]
+                self.last_node = self.last_node.parent
+
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
 
     def adjust_max_prefix_ids(self):
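The new `elif enable_hierarchical_cache` branch handles the case where the matched node was evicted to host between scheduling and execution: it walks up the radix tree and drops the corresponding tail of `prefix_indices`. A self-contained toy version of that walk follows; the `Node` class is hypothetical, only the loop shape follows the diff.

    from dataclasses import dataclass
    from typing import List, Optional


    @dataclass
    class Node:
        # host_value stands in for the host-side backup of this node's tokens.
        host_value: List[int]
        evicted: bool = False
        parent: Optional["Node"] = None


    def trim_evicted_prefix(prefix_indices: List[int], last_node: Node):
        # Drop the tail of prefix_indices for every evicted node on the path, as in the diff.
        while last_node.evicted:
            prefix_indices = prefix_indices[: -len(last_node.host_value)]
            last_node = last_node.parent
        return prefix_indices, last_node


    root = Node(host_value=[])
    mid = Node(host_value=[0, 1, 2], parent=root)
    leaf = Node(host_value=[3, 4], evicted=True, parent=mid)
    print(trim_evicted_prefix([0, 1, 2, 3, 4], leaf)[0])  # [0, 1, 2]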
@@ -60,7 +60,8 @@ from sglang.srt.managers.io_struct import (
     CloseSessionReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
-    FlushCacheReq,
+    FlushCacheReqInput,
+    FlushCacheReqOutput,
     GetInternalStateReq,
     GetInternalStateReqOutput,
     GetWeightsByNameReqInput,
@@ -402,7 +403,7 @@ class Scheduler(
             [
                 (TokenizedGenerateReqInput, self.handle_generate_request),
                 (TokenizedEmbeddingReqInput, self.handle_embedding_request),
-                (FlushCacheReq, self.flush_cache_wrapped),
+                (FlushCacheReqInput, self.flush_cache_wrapped),
                 (AbortReq, self.abort_request),
                 (OpenSessionReqInput, self.open_session),
                 (CloseSessionReqInput, self.close_session),
@@ -488,6 +489,8 @@ class Scheduler(
                 tp_cache_group=self.tp_cpu_group,
                 page_size=self.page_size,
                 hicache_ratio=server_args.hicache_ratio,
+                hicache_size=server_args.hicache_size,
+                hicache_write_policy=server_args.hicache_write_policy,
             )
         else:
             self.tree_cache = RadixCache(
@@ -1596,8 +1599,9 @@ class Scheduler(
             time.sleep(5)
             self.parent_process.send_signal(signal.SIGQUIT)
 
-    def flush_cache_wrapped(self, recv_req: FlushCacheReq):
-        self.flush_cache()
+    def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
+        success = self.flush_cache()
+        return FlushCacheReqOutput(success=success)
 
     def flush_cache(self):
         """Flush the memory pool and cache."""
@@ -2010,9 +2014,15 @@ def run_scheduler_process(
         else:
             scheduler.event_loop_normal()
     elif disaggregation_mode == DisaggregationMode.PREFILL:
-        scheduler.event_loop_normal_disagg_prefill()
+        if scheduler.enable_overlap:
+            scheduler.event_loop_overlap_disagg_prefill()
+        else:
+            scheduler.event_loop_normal_disagg_prefill()
     elif disaggregation_mode == DisaggregationMode.DECODE:
-        scheduler.event_loop_normal_disagg_decode()
+        if scheduler.enable_overlap:
+            scheduler.event_loop_overlap_disagg_decode()
+        else:
+            scheduler.event_loop_normal_disagg_decode()
 
     except Exception:
         traceback = get_exception_traceback()
@@ -66,7 +66,8 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
-    FlushCacheReq,
+    FlushCacheReqInput,
+    FlushCacheReqOutput,
     GenerateReqInput,
     GetInternalStateReq,
     GetInternalStateReqOutput,
@@ -264,6 +265,9 @@ class TokenizerManager:
         self.resume_memory_occupation_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.flush_cache_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
         self.start_profile_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -314,6 +318,10 @@ class TokenizerManager:
                 ResumeMemoryOccupationReqOutput,
                 self.resume_memory_occupation_communicator.handle_recv,
             ),
+            (
+                FlushCacheReqOutput,
+                self.flush_cache_communicator.handle_recv,
+            ),
             (
                 ProfileReqOutput,
                 self.start_profile_communicator.handle_recv,
@@ -415,38 +423,60 @@ class TokenizerManager:
             )
             if image_inputs and "input_ids" in image_inputs:
                 input_ids = image_inputs["input_ids"]
-        if self.is_generation:
-            return_logprob = obj.return_logprob
-            logprob_start_len = obj.logprob_start_len
-            top_logprobs_num = obj.top_logprobs_num
-            token_ids_logprob = obj.token_ids_logprob
-            session_params = (
-                SessionParams(**obj.session_params) if obj.session_params else None
-            )
+
+        self._validate_token_len(obj, input_ids)
+        return self._create_tokenized_object(
+            obj, input_text, input_ids, input_embeds, image_inputs
+        )
+
+    def _validate_token_len(
+        self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
+    ) -> None:
+        """Validates that the input token count and the requested token count doesn't exceed the model's context length."""
 
         input_token_num = len(input_ids) if input_ids is not None else 0
+        # Check if input alone exceeds context length
         if input_token_num >= self.context_len:
             raise ValueError(
                 f"The input ({input_token_num} tokens) is longer than the "
                 f"model's context length ({self.context_len} tokens)."
             )
 
+        # Check total tokens (input + max_new_tokens)
+        max_new_tokens = obj.sampling_params.get("max_new_tokens")
         if (
-            obj.sampling_params.get("max_new_tokens") is not None
-            and obj.sampling_params.get("max_new_tokens") + input_token_num
-            >= self.context_len
+            max_new_tokens is not None
+            and (max_new_tokens + input_token_num) >= self.context_len
         ):
-            raise ValueError(
+            total_tokens = max_new_tokens + input_token_num
+            error_msg = (
                 f"Requested token count exceeds the model's maximum context length "
-                f"of {self.context_len} tokens. You requested a total of "
-                f"{obj.sampling_params.get('max_new_tokens') + input_token_num} "
+                f"of {self.context_len} tokens. You requested a total of {total_tokens} "
                 f"tokens: {input_token_num} tokens from the input messages and "
-                f"{obj.sampling_params.get('max_new_tokens')} tokens for the "
-                f"completion. Please reduce the number of tokens in the input "
-                f"messages or the completion to fit within the limit."
+                f"{max_new_tokens} tokens for the completion. Please reduce the number "
+                f"of tokens in the input messages or the completion to fit within the limit."
+            )
+            raise ValueError(error_msg)
+
+    def _create_tokenized_object(
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        input_text: str,
+        input_ids: List[int],
+        input_embeds: Optional[Union[List[float], None]] = None,
+        image_inputs: Optional[Dict] = None,
+    ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
+        """Create a tokenized request object from common parameters."""
+
+        if self.is_generation:
+            return_logprob = obj.return_logprob
+            logprob_start_len = obj.logprob_start_len
+            top_logprobs_num = obj.top_logprobs_num
+            token_ids_logprob = obj.token_ids_logprob
+            session_params = (
+                SessionParams(**obj.session_params) if obj.session_params else None
             )
 
-        # Parse sampling parameters
         sampling_params = SamplingParams(**obj.sampling_params)
         sampling_params.normalize(self.tokenizer)
         sampling_params.verify()
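The refactor above splits `_tokenize_one_request` into a validation step and an object-construction step. The validation itself is plain arithmetic over the prompt length, `max_new_tokens`, and the context window; a standalone sketch of the same two checks follows (the function name and error wording here are illustrative, not the exact sglang messages).

    from typing import List, Optional


    def validate_token_len(
        input_ids: List[int], max_new_tokens: Optional[int], context_len: int
    ) -> None:
        input_token_num = len(input_ids) if input_ids is not None else 0
        # 1) The prompt alone must fit in the context window.
        if input_token_num >= context_len:
            raise ValueError(
                f"The input ({input_token_num} tokens) is longer than the "
                f"model's context length ({context_len} tokens)."
            )
        # 2) Prompt plus requested completion must also fit.
        if max_new_tokens is not None and input_token_num + max_new_tokens >= context_len:
            raise ValueError(
                f"Requested {input_token_num + max_new_tokens} tokens in total, "
                f"which exceeds the context length of {context_len} tokens."
            )


    validate_token_len(list(range(100)), max_new_tokens=20, context_len=4096)  # passes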
@@ -483,6 +513,50 @@ class TokenizerManager:
 
         return tokenized_obj
 
+    async def _batch_tokenize_and_process(
+        self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
+    ) -> List[Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]]:
+        """Handle batch tokenization for text inputs only."""
+        logger.debug(f"Starting batch tokenization for {batch_size} text requests")
+
+        # Collect requests and texts
+        requests = [obj[i] for i in range(batch_size)]
+        texts = [req.text for req in requests]
+
+        # Batch tokenize all texts
+        encoded = self.tokenizer(texts)
+        input_ids_list = encoded["input_ids"]
+
+        # Process all requests
+        tokenized_objs = []
+        for i, req in enumerate(requests):
+            self._validate_token_len(obj[i], input_ids_list[i])
+            tokenized_objs.append(
+                self._create_tokenized_object(
+                    req, req.text, input_ids_list[i], None, None
+                )
+            )
+        logger.debug(f"Completed batch processing for {batch_size} requests")
+        return tokenized_objs
+
+    def _validate_batch_tokenization_constraints(
+        self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
+    ) -> None:
+        """Validate constraints for batch tokenization processing."""
+        for i in range(batch_size):
+            if self.is_generation and obj[i].image_data:
+                raise ValueError(
+                    "For image input processing do not set `enable_tokenizer_batch_encode`."
+                )
+            if obj[i].input_ids is not None:
+                raise ValueError(
+                    "Batch tokenization is not needed for pre-tokenized input_ids. Do not set `enable_tokenizer_batch_encode`."
+                )
+            if obj[i].input_embeds is not None:
+                raise ValueError(
+                    "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
+                )
+
     def _send_one_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
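`_batch_tokenize_and_process` leans on the Hugging Face tokenizer's ability to encode a list of strings in one call, which returns one `input_ids` list per text and amortizes per-call overhead. A minimal sketch of that pattern outside sglang is shown below; the model name and sample texts are arbitrary choices for illustration, and running it requires the `transformers` package and a model download.

    from transformers import AutoTokenizer

    # Any HF tokenizer works here; gpt2 is just a small, convenient choice.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    texts = ["hello world", "batch tokenization amortizes per-call overhead"]
    encoded = tokenizer(texts)          # one call for the whole batch
    input_ids_list = encoded["input_ids"]

    for text, ids in zip(texts, input_ids_list):
        print(len(ids), "tokens for:", text)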
@@ -560,14 +634,27 @@ class TokenizerManager:
 
         generators = []
         rids = []
+
         if getattr(obj, "parallel_sample_num", 1) == 1:
-            # Send all requests
-            for i in range(batch_size):
-                tmp_obj = obj[i]
-                tokenized_obj = await self._tokenize_one_request(tmp_obj)
-                self._send_one_request(tmp_obj, tokenized_obj, created_time)
-                generators.append(self._wait_one_response(tmp_obj, request))
-                rids.append(tmp_obj.rid)
+            if self.server_args.enable_tokenizer_batch_encode:
+                # Validate batch tokenization constraints
+                self._validate_batch_tokenization_constraints(batch_size, obj)
+
+                tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)
+
+                for i, tokenized_obj in enumerate(tokenized_objs):
+                    tmp_obj = obj[i]
+                    self._send_one_request(tmp_obj, tokenized_obj, created_time)
+                    generators.append(self._wait_one_response(tmp_obj, request))
+                    rids.append(tmp_obj.rid)
+            else:
+                # Sequential tokenization and processing
+                for i in range(batch_size):
+                    tmp_obj = obj[i]
+                    tokenized_obj = await self._tokenize_one_request(tmp_obj)
+                    self._send_one_request(tmp_obj, tokenized_obj, created_time)
+                    generators.append(self._wait_one_response(tmp_obj, request))
+                    rids.append(tmp_obj.rid)
         else:
             # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
             if batch_size > 128:
@@ -628,9 +715,8 @@ class TokenizerManager:
         except StopAsyncIteration:
             pass
 
-    def flush_cache(self):
-        req = FlushCacheReq()
-        self.send_to_scheduler.send_pyobj(req)
+    async def flush_cache(self) -> FlushCacheReqOutput:
+        return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
 
     def abort_request(self, rid: str):
         if rid not in self.rid_to_state:
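With `flush_cache` now routed through a `_Communicator`, callers get a `FlushCacheReqOutput` back instead of a fire-and-forget message. The sketch below shows how a caller at the HTTP layer might await it and surface the result; the endpoint wiring and the `FakeTokenizerManager` stand-in are hypothetical, only the request/response names come from the diff.

    import asyncio
    from dataclasses import dataclass


    @dataclass
    class FlushCacheReqOutput:
        success: bool


    class FakeTokenizerManager:
        # Stand-in for TokenizerManager: pretend the scheduler acked the flush.
        async def flush_cache(self) -> FlushCacheReqOutput:
            await asyncio.sleep(0)          # simulate the round trip to the scheduler
            return FlushCacheReqOutput(success=True)


    async def flush_cache_endpoint(manager: FakeTokenizerManager) -> dict:
        result = await manager.flush_cache()
        # The caller can now report whether the flush actually happened.
        return {"success": result.success}


    print(asyncio.run(flush_cache_endpoint(FakeTokenizerManager())))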
@@ -116,6 +116,7 @@ class TpModelWorker:
             ),
             self.model_runner.req_to_token_pool.size,
         )
+        assert self.max_running_requests > 0, "max_running_request is zero"
         self.max_req_len = min(
             self.model_config.context_len - 1,
             self.max_total_num_tokens - 1,
@@ -29,15 +29,17 @@ class HiRadixCache(RadixCache):
         tp_cache_group: torch.distributed.ProcessGroup,
         page_size: int,
         hicache_ratio: float,
+        hicache_size: int,
+        hicache_write_policy: str,
     ):
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
             self.token_to_kv_pool_host = MHATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio, page_size
+                self.kv_cache, hicache_ratio, hicache_size, page_size
             )
         elif isinstance(self.kv_cache, MLATokenToKVPool):
             self.token_to_kv_pool_host = MLATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio, page_size
+                self.kv_cache, hicache_ratio, hicache_size, page_size
             )
         else:
             raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
@@ -50,6 +52,7 @@ class HiRadixCache(RadixCache):
             self.token_to_kv_pool_host,
             page_size,
             load_cache_event=self.load_cache_event,
+            write_policy=hicache_write_policy,
         )
 
         # record the nodes with ongoing write through
@@ -57,7 +60,9 @@ class HiRadixCache(RadixCache):
         # record the node segments with ongoing load back
         self.ongoing_load_back = {}
         # todo: dynamically adjust the threshold
-        self.write_through_threshold = 1
+        self.write_through_threshold = (
+            1 if hicache_write_policy == "write_through" else 3
+        )
         self.load_back_threshold = 10
         super().__init__(
             req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
@@ -76,7 +81,7 @@ class HiRadixCache(RadixCache):
             height += 1
         return height
 
-    def write_backup(self, node: TreeNode):
+    def write_backup(self, node: TreeNode, write_back=False):
         host_indices = self.cache_controller.write(
             device_indices=node.value,
             node_id=node.id,
@@ -90,21 +95,29 @@ class HiRadixCache(RadixCache):
         if host_indices is not None:
             node.host_value = host_indices
             self.ongoing_write_through[node.id] = node
-            self.inc_lock_ref(node)
+            if not write_back:
+                # no need to lock nodes if write back
+                self.inc_lock_ref(node)
         else:
             return 0
 
         return len(host_indices)
 
     def inc_hit_count(self, node: TreeNode):
-        if self.cache_controller.write_policy != "write_through_selective":
+        if node.backuped or self.cache_controller.write_policy == "write_back":
             return
         node.hit_count += 1
-        if node.host_value is None and node.hit_count > self.write_through_threshold:
+        if node.hit_count >= self.write_through_threshold:
             self.write_backup(node)
             node.hit_count = 0
 
-    def writing_check(self):
+    def writing_check(self, write_back=False):
+        if write_back:
+            # blocking till all write back complete
+            while len(self.ongoing_write_through) > 0:
+                ack_id = self.cache_controller.ack_write_queue.get()
+                del self.ongoing_write_through[ack_id]
+            return
         queue_size = torch.tensor(
             self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
         )
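`inc_hit_count` now drives write-through for every policy except `write_back`: a node is copied to host only after it has been hit `write_through_threshold` times (1 under `write_through`, 3 otherwise), and nodes that already have a host copy are skipped. A toy model of that promotion logic follows; the classes are hypothetical stand-ins, only the thresholds and the guard conditions mirror the diff.

    class ToyNode:
        def __init__(self):
            self.hit_count = 0
            self.backuped = False      # True once a host copy exists


    class ToyCache:
        def __init__(self, write_policy: str):
            self.write_policy = write_policy
            # write_through backs up on the first hit; other policies wait for 3 hits.
            self.write_through_threshold = 1 if write_policy == "write_through" else 3

        def inc_hit_count(self, node: ToyNode) -> None:
            if node.backuped or self.write_policy == "write_back":
                return
            node.hit_count += 1
            if node.hit_count >= self.write_through_threshold:
                node.backuped = True   # stand-in for write_backup(node)
                node.hit_count = 0


    cache = ToyCache("write_through_selective")
    node = ToyNode()
    for _ in range(3):
        cache.inc_hit_count(node)
    print(node.backuped)  # True after the third hit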
@@ -143,29 +156,25 @@ class HiRadixCache(RadixCache):
         heapq.heapify(leaves)
 
         num_evicted = 0
-        pending_nodes = []
+        write_back_nodes = []
         while num_evicted < num_tokens and len(leaves):
             x = heapq.heappop(leaves)
 
             if x.lock_ref > 0:
                 continue
 
-            if x.host_value is None:
+            if not x.backuped:
                 if self.cache_controller.write_policy == "write_back":
-                    num_evicted += self.write_backup(x)
-                    pending_nodes.append(x)
-                elif self.cache_controller.write_policy == "write_through_selective":
-                    num_evicted += self._evict_write_through_selective(x)
+                    # write to host if the node is not backuped
+                    num_evicted += self.write_backup(x, write_back=True)
+                    write_back_nodes.append(x)
                 else:
-                    assert (
-                        self.cache_controller.write_policy != "write_through"
-                    ), "write_through should be inclusive"
-                    raise NotImplementedError
+                    num_evicted += self._evict_regular(x)
             else:
-                num_evicted += self._evict_write_through(x)
+                num_evicted += self._evict_backuped(x)
 
             for child in x.parent.children.values():
-                if child in pending_nodes:
+                if child in write_back_nodes:
                     continue
                 if not child.evicted:
                     break
@@ -174,15 +183,12 @@ class HiRadixCache(RadixCache):
                 heapq.heappush(leaves, x.parent)
 
         if self.cache_controller.write_policy == "write_back":
-            # blocking till all write back complete
-            while len(self.ongoing_write_through) > 0:
-                self.writing_check()
-                time.sleep(0.1)
-            for node in pending_nodes:
-                assert node.host_value is not None
-                self._evict_write_through(node)
+            self.writing_check(write_back=True)
+            for node in write_back_nodes:
+                assert node.backuped
+                self._evict_backuped(node)
 
-    def _evict_write_through(self, node: TreeNode):
+    def _evict_backuped(self, node: TreeNode):
         # evict a node already written to host
         num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
         assert num_evicted > 0
@@ -190,7 +196,7 @@
         node.value = None
         return num_evicted
 
-    def _evict_write_through_selective(self, node: TreeNode):
+    def _evict_regular(self, node: TreeNode):
         # evict a node not initiated write to host
         self.cache_controller.mem_pool_device_allocator.free(node.value)
         num_evicted = len(node.value)
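Taken together, the eviction hunks collapse the old three-way policy switch into two leaf cases: a node that already has a host backup goes through `_evict_backuped`, and one that does not is either written back first (under `write_back`) or freed directly via `_evict_regular`. The compact sketch below mirrors only that branch order; the `LeafStub` class and its fields are hypothetical stand-ins for the real tree nodes and cache controller.

    class LeafStub:
        """Hypothetical stand-in for a radix-tree leaf."""

        def __init__(self, backuped: bool, num_tokens: int):
            self.backuped = backuped              # True once a host copy exists
            self.num_tokens = num_tokens
            self.device_value = list(range(num_tokens))


    def evict_leaf(leaf: LeafStub, write_policy: str, write_back_nodes: list) -> int:
        """Mirror the branch order of the evict() hunks above; returns tokens reclaimed."""
        if not leaf.backuped:
            if write_policy == "write_back":
                # Not on host yet: start a write-back and defer freeing until it is acked.
                leaf.backuped = True              # stand-in for write_backup(leaf, write_back=True)
                write_back_nodes.append(leaf)
            else:
                # _evict_regular: no host copy needed, free the device memory directly.
                leaf.device_value = None
        else:
            # _evict_backuped: a host copy already exists, drop only the device copy.
            leaf.device_value = None
        return leaf.num_tokens


    pending = []
    freed = evict_leaf(LeafStub(backuped=False, num_tokens=16), "write_back", pending)
    print(freed, len(pending))  # 16 1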
@@ -339,11 +345,13 @@ class HiRadixCache(RadixCache):
             prefix_len = self.key_match_fn(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
+                self.inc_hit_count(new_node)
                 if not new_node.evicted:
                     value.append(new_node.value)
                 node = new_node
                 break
             else:
+                self.inc_hit_count(child)
                 if not child.evicted:
                     value.append(child.value)
                 node = child
@@ -369,7 +377,7 @@ class HiRadixCache(RadixCache):
         else:
             new_node.value = child.value[:split_len]
             child.value = child.value[split_len:]
-        if child.host_value is not None:
+        if child.backuped:
             new_node.host_value = child.host_value[:split_len]
             child.host_value = child.host_value[split_len:]
         child.parent = new_node
@@ -426,8 +434,8 @@ class HiRadixCache(RadixCache):
         node.children[child_key] = new_node
         self.evictable_size_ += len(value)
 
-        if self.cache_controller.write_policy == "write_through":
-            self.write_backup(new_node)
+        if self.cache_controller.write_policy != "write_back":
+            self.inc_hit_count(new_node)
         return total_prefix_length
 
     def _collect_leaves_device(self):
@@ -446,13 +446,28 @@ class MLATokenToKVPool(KVCache):
         ]
 
         self.layer_transfer_counter = None
+        self.page_size = page_size
+
+        kv_size = self.get_kv_size_bytes()
+        logger.info(
+            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
+        )
+
+    def get_kv_size_bytes(self):
+        assert hasattr(self, "kv_buffer")
+        kv_size_bytes = 0
+        for kv_cache in self.kv_buffer:
+            kv_size_bytes += np.prod(kv_cache.shape) * kv_cache.dtype.itemsize
+        return kv_size_bytes
 
     # for disagg
     def get_contiguous_buf_infos(self):
         # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
         kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
         kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
-        kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)]
+        kv_item_lens = [
+            self.kv_buffer[i][0].nbytes * self.page_size for i in range(self.layer_num)
+        ]
        return kv_data_ptrs, kv_data_lens, kv_item_lens
 
     def get_key_buffer(self, layer_id: int):
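`get_kv_size_bytes` is plain shape-times-itemsize accounting over the per-layer buffers. The sketch below reproduces the same arithmetic with ordinary torch tensors; the layer count, dtype, dimensions, and the `GB = 1024**3` constant are assumptions made for illustration, and `numel() * element_size()` is used as an equivalent of the diff's `np.prod(shape) * dtype.itemsize`.

    import torch

    GB = 1024**3

    # Pretend MLA KV buffers: one tensor per layer, shape (num_tokens, 1, head_dim).
    layer_num, num_tokens, head_dim = 4, 1024, 576
    kv_buffer = [
        torch.zeros(num_tokens, 1, head_dim, dtype=torch.bfloat16)
        for _ in range(layer_num)
    ]

    kv_size_bytes = 0
    for kv_cache in kv_buffer:
        # Element count times bytes per element, summed over all layers.
        kv_size_bytes += kv_cache.numel() * kv_cache.element_size()

    print(f"KV size: {kv_size_bytes / GB:.4f} GB")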
@@ -621,26 +636,27 @@ class HostKVCache(abc.ABC):
         self,
         device_pool: MHATokenToKVPool,
         host_to_device_ratio: float,
+        host_size: int,
         pin_memory: bool,
         device: str,
         page_size: int,
     ):
-        assert (
-            host_to_device_ratio >= 1
-        ), "The host memory should be larger than the device memory with the current protocol"
-        # todo, other ways of configuring the size
-
         self.device_pool = device_pool
-        self.host_to_device_ratio = host_to_device_ratio
+        self.dtype = device_pool.store_dtype
         self.pin_memory = pin_memory
         self.device = device
         self.page_size = page_size
-
-        self.size = int(device_pool.size * host_to_device_ratio)
+        self.size_per_token = self.get_size_per_token()
+        if host_size > 0:
+            self.size = int(host_size * 1e9 // self.size_per_token)
+        else:
+            self.size = int(device_pool.size * host_to_device_ratio)
         # Align the host memory pool size to the page size
         self.size = self.size - (self.size % self.page_size)
-        self.dtype = device_pool.store_dtype
-        self.size_per_token = self.get_size_per_token()
+
+        assert (
+            self.size > device_pool.size
+        ), "The host memory should be larger than the device memory with the current protocol"
 
         # Verify there is enough available host memory.
         host_mem = psutil.virtual_memory()
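Host KV pool sizing now prefers an explicit `hicache_size` (interpreted as GB, converted with `1e9`) and falls back to the old device-size ratio; the result is then rounded down to a page boundary and must still exceed the device pool. The arithmetic in isolation, as a small sketch with names following the diff and arbitrary example values:

    def host_pool_size(
        host_size_gb: int,
        host_to_device_ratio: float,
        device_pool_size: int,
        size_per_token: int,
        page_size: int,
    ) -> int:
        if host_size_gb > 0:
            # Explicit size in GB wins: convert bytes to a token budget.
            size = int(host_size_gb * 1e9 // size_per_token)
        else:
            # Fall back to scaling the device pool by the configured ratio.
            size = int(device_pool_size * host_to_device_ratio)
        # Align the host memory pool size to the page size.
        size -= size % page_size
        assert size > device_pool_size, "host pool must exceed the device pool"
        return size


    # e.g. 10 GB of host memory, ~70 KB per token, page size 64
    print(host_pool_size(10, 2.0, 65536, 70 * 1024, 64))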
@@ -792,12 +808,13 @@ class MHATokenToKVPoolHost(HostKVCache):
         self,
         device_pool: MHATokenToKVPool,
         host_to_device_ratio: float,
+        host_size: int,
         page_size: int,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool, host_to_device_ratio, pin_memory, device, page_size
+            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
         )
 
     def get_size_per_token(self):
@@ -866,12 +883,13 @@ class MLATokenToKVPoolHost(HostKVCache):
         self,
         device_pool: MLATokenToKVPool,
         host_to_device_ratio: float,
+        host_size: int,
         page_size: int,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool, host_to_device_ratio, pin_memory, device, page_size
+            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
         )
 
     def get_size_per_token(self):
@@ -35,7 +35,11 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )
 from sglang.srt.patch_torch import monkey_patch_torch_compile
-from sglang.srt.utils import get_available_gpu_memory, is_hip
+from sglang.srt.utils import (
+    get_available_gpu_memory,
+    get_device_memory_capacity,
+    is_hip,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -129,7 +133,8 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
             list(range(1, 9)) + list(range(10, 33, 2)) + list(range(40, 161, 16))
         )
 
-        if _is_hip:
+        gpu_mem = get_device_memory_capacity()
+        if gpu_mem is not None and gpu_mem > 81920:
             capture_bs += list(range(160, 257, 8))
 
     if max(capture_bs) > model_runner.req_to_token_pool.size:
@@ -140,12 +145,11 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
         ]
 
     capture_bs = list(sorted(set(capture_bs)))
-    capture_bs = [
-        bs
-        for bs in capture_bs
-        if bs <= model_runner.req_to_token_pool.size
-        and bs <= server_args.cuda_graph_max_bs
-    ]
+
+    assert len(capture_bs) > 0 and capture_bs[0] > 0
+    capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
+    if server_args.cuda_graph_max_bs:
+        capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
     compile_bs = (
         [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
         if server_args.enable_torch_compile
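The CUDA graph batch-size selection now keys off total device memory rather than the HIP flag, and the `cuda_graph_max_bs` cap is applied only when it is set. The sketch below shows the resulting list construction in isolation; the 81920 threshold comes from the diff (presumably MB, i.e. roughly 80 GB), while the helper function itself and its parameters are hypothetical.

    from typing import List, Optional


    def batch_sizes_to_capture(
        gpu_mem_mb: Optional[int],
        req_to_token_pool_size: int,
        cuda_graph_max_bs: Optional[int],
    ) -> List[int]:
        capture_bs = list(range(1, 9)) + list(range(10, 33, 2)) + list(range(40, 161, 16))
        # GPUs with more than ~80 GB also capture the larger decode batch sizes.
        if gpu_mem_mb is not None and gpu_mem_mb > 81920:
            capture_bs += list(range(160, 257, 8))
        capture_bs = sorted(set(capture_bs))
        assert len(capture_bs) > 0 and capture_bs[0] > 0
        capture_bs = [bs for bs in capture_bs if bs <= req_to_token_pool_size]
        if cuda_graph_max_bs:
            capture_bs = [bs for bs in capture_bs if bs <= cuda_graph_max_bs]
        return capture_bs


    print(batch_sizes_to_capture(gpu_mem_mb=98304, req_to_token_pool_size=4096, cuda_graph_max_bs=160)[-5:])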
@@ -186,6 +190,7 @@ class CudaGraphRunner:
 
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
+
         self.capture_forward_mode = ForwardMode.DECODE
         self.capture_hidden_mode = CaptureHiddenMode.NULL
         self.num_tokens_per_bs = 1