sglang 0.5.0rc1__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (76)
  1. sglang/bench_one_batch.py +0 -1
  2. sglang/srt/configs/model_config.py +1 -0
  3. sglang/srt/disaggregation/decode.py +0 -1
  4. sglang/srt/entrypoints/engine.py +2 -2
  5. sglang/srt/entrypoints/http_server.py +64 -0
  6. sglang/srt/entrypoints/openai/protocol.py +2 -0
  7. sglang/srt/entrypoints/openai/serving_chat.py +1 -0
  8. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  9. sglang/srt/layers/attention/flashinfer_backend.py +3 -0
  10. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
  11. sglang/srt/layers/attention/triton_backend.py +24 -27
  12. sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
  13. sglang/srt/layers/attention/trtllm_mla_backend.py +10 -3
  14. sglang/srt/layers/communicator.py +7 -7
  15. sglang/srt/layers/dp_attention.py +118 -27
  16. sglang/srt/layers/logits_processor.py +12 -18
  17. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/multimodal.py +156 -40
  29. sglang/srt/layers/quantization/__init__.py +5 -32
  30. sglang/srt/layers/quantization/awq.py +15 -16
  31. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
  32. sglang/srt/layers/quantization/gptq.py +12 -17
  33. sglang/srt/layers/quantization/marlin_utils.py +15 -5
  34. sglang/srt/layers/quantization/modelopt_quant.py +52 -30
  35. sglang/srt/layers/quantization/mxfp4.py +16 -2
  36. sglang/srt/layers/quantization/utils.py +52 -2
  37. sglang/srt/layers/sampler.py +5 -2
  38. sglang/srt/lora/layers.py +6 -2
  39. sglang/srt/managers/cache_controller.py +4 -1
  40. sglang/srt/managers/io_struct.py +14 -0
  41. sglang/srt/managers/schedule_batch.py +18 -39
  42. sglang/srt/managers/scheduler.py +3 -4
  43. sglang/srt/managers/tokenizer_manager.py +28 -18
  44. sglang/srt/mem_cache/allocator.py +8 -157
  45. sglang/srt/mem_cache/allocator_ascend.py +158 -0
  46. sglang/srt/mem_cache/chunk_cache.py +1 -1
  47. sglang/srt/model_executor/cuda_graph_runner.py +8 -21
  48. sglang/srt/model_executor/forward_batch_info.py +8 -10
  49. sglang/srt/model_executor/model_runner.py +57 -53
  50. sglang/srt/models/deepseek_nextn.py +2 -1
  51. sglang/srt/models/deepseek_v2.py +5 -3
  52. sglang/srt/models/glm4_moe.py +2 -2
  53. sglang/srt/models/glm4_moe_nextn.py +2 -1
  54. sglang/srt/models/gpt_oss.py +7 -2
  55. sglang/srt/models/llama.py +10 -2
  56. sglang/srt/models/llama4.py +18 -5
  57. sglang/srt/models/qwen2.py +2 -2
  58. sglang/srt/models/qwen2_moe.py +20 -5
  59. sglang/srt/models/qwen3_classification.py +78 -0
  60. sglang/srt/models/qwen3_moe.py +18 -5
  61. sglang/srt/models/step3_vl.py +6 -2
  62. sglang/srt/operations.py +17 -2
  63. sglang/srt/sampling/sampling_batch_info.py +7 -4
  64. sglang/srt/server_args.py +33 -7
  65. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
  66. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  67. sglang/srt/two_batch_overlap.py +4 -8
  68. sglang/test/test_marlin_moe.py +1 -1
  69. sglang/test/test_marlin_utils.py +1 -1
  70. sglang/version.py +1 -1
  71. {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +5 -5
  72. {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +75 -63
  73. sglang/srt/layers/quantization/scalar_type.py +0 -352
  74. {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
  75. {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
  76. {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
@@ -84,7 +84,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "device",
     "disable_chunked_prefix_cache",
     "disable_radix_cache",
-    "enable_dp_attention",
     "enable_two_batch_overlap",
     "tbo_token_distribution_threshold",
     "enable_dp_lm_head",
@@ -113,6 +112,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "enable_multimodal",
     "enable_symm_mem",
     "quantization",
+    "enable_custom_logit_processor",
 ]

 # Put some global args for easy access
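With "enable_custom_logit_processor" added to GLOBAL_SERVER_ARGS_KEYS (and, as later hunks show, dropped as an explicit ScheduleBatch field and init_new argument), consumers are expected to read the flag from the broadcast server-args dict instead of a per-batch attribute. A minimal sketch of that access pattern, assuming global_server_args_dict is defined alongside this list in schedule_batch.py (the consuming call sites are not shown in this diff):

# Sketch only: read the flag from the global server-args dict rather than a
# ScheduleBatch field. Assumes global_server_args_dict lives next to
# GLOBAL_SERVER_ARGS_KEYS in schedule_batch.py.
from sglang.srt.managers.schedule_batch import global_server_args_dict


def custom_logit_processor_enabled() -> bool:
    # Keys listed in GLOBAL_SERVER_ARGS_KEYS are copied from ServerArgs into
    # this dict, so the new entry makes the flag visible here.
    return bool(global_server_args_dict.get("enable_custom_logit_processor", False))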
@@ -909,12 +909,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None

-    # Enable custom logit processor
-    enable_custom_logit_processor: bool = False
-
     # Whether to return hidden states
     return_hidden_states: bool = False

+    # Whether this batch is prefill-only (no token generation needed)
+    is_prefill_only: bool = False
+
     # hicache pointer for synchronizing data loading from CPU to GPU
     hicache_consumer_index: int = 0

@@ -928,7 +928,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         model_config: ModelConfig,
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
-        enable_custom_logit_processor: bool,
         chunked_req: Optional[Req] = None,
     ):
         return_logprob = any(req.return_logprob for req in reqs)
@@ -955,8 +954,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
-            enable_custom_logit_processor=enable_custom_logit_processor,
             return_hidden_states=any(req.return_hidden_states for req in reqs),
+            is_prefill_only=all(
+                req.sampling_params.max_new_tokens == 0 for req in reqs
+            ),
             chunked_req=chunked_req,
         )

@@ -1009,6 +1010,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         extend_num_tokens: int,
         backup_state: bool = False,
     ):
+        # Over estimate the number of tokens: assume each request needs a new page.
         num_tokens = (
             extend_num_tokens
             + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
@@ -1041,8 +1043,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         last_loc: torch.Tensor,
         backup_state: bool = False,
     ):
+        # Over estimate the number of tokens: assume each request needs a new page.
         num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size
-
         self._evict_tree_cache_if_needed(num_tokens)

         if backup_state:
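The comment added in both hunks describes a deliberate over-estimate: before deciding whether to evict from the tree cache, the batch assumes every request may open one extra page. A rough numeric illustration of the extend-case formula, with assumed values:

# Illustrative values only (not from the diff): page_size = 64,
# three requests extending 10, 70, and 1 tokens.
page_size = 64
extend_lens = [10, 70, 1]

extend_num_tokens = sum(extend_lens)   # 81
bs = len(extend_lens)                  # 3

# Matches `extend_num_tokens + len(seq_lens) * page_size` above:
# actual tokens plus one full page per request, covering the worst case
# where each request's extension starts a fresh page.
num_tokens = extend_num_tokens + bs * page_size   # 81 + 192 = 273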
@@ -1721,38 +1723,18 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         extend_prefix_lens = self.prefix_lens
         extend_logprob_start_lens = self.extend_logprob_start_lens

-        if self.forward_mode.is_decode_or_idle():
-            attention_backend_str = global_server_args_dict["decode_attention_backend"]
-        else:
-            attention_backend_str = global_server_args_dict["prefill_attention_backend"]
-        # Create seq_lens_cpu when needed
-        if (
-            attention_backend_str
-            in [
-                "fa3",
-                "flashinfer",
-                "flashmla",
-                "cutlass_mla",
-                "ascend",
-                "trtllm_mha",
-                "aiter",
-            ]
-            or global_server_args_dict["enable_two_batch_overlap"]
-        ):
-            seq_lens_cpu = (
-                seq_lens_cpu_cache
-                if seq_lens_cpu_cache is not None
-                else self.seq_lens.cpu()
-            )
-        else:
-            seq_lens_cpu = None
-
         if self.sampling_info:
             if self.has_grammar:
                 self.sampling_info.grammars = [req.grammar for req in self.reqs]
             else:
                 self.sampling_info.grammars = None

+        seq_lens_cpu = (
+            seq_lens_cpu_cache
+            if seq_lens_cpu_cache is not None
+            else self.seq_lens.cpu()
+        )
+
         global bid
         bid += 1
         return ModelWorkerBatch(
@@ -1815,18 +1797,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
             spec_algorithm=self.spec_algorithm,
-            enable_custom_logit_processor=self.enable_custom_logit_processor,
             global_num_tokens=self.global_num_tokens,
             global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
             can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
             is_extend_in_batch=self.is_extend_in_batch,
+            is_prefill_only=self.is_prefill_only,
         )

-    def _evict_tree_cache_if_needed(
-        self,
-        num_tokens: int,
-    ) -> None:
-        if isinstance(self.tree_cache, SWAChunkCache):
+    def _evict_tree_cache_if_needed(self, num_tokens: int):
+        if isinstance(self.tree_cache, (SWAChunkCache, ChunkCache)):
             return

         if self.is_hybrid:
@@ -1466,8 +1466,9 @@ class Scheduler(
        if self.last_batch.batch_size() < last_bs:
            self.running_batch.batch_is_full = False

-        # Merge the new batch into the running batch
-        if not self.last_batch.is_empty():
+        # Merge the new batch into the running batch.
+        # For prefill-only batch, we can avoid going through decoding step.
+        if not self.last_batch.is_empty() and not self.last_batch.is_prefill_only:
            if self.running_batch.is_empty():
                self.running_batch = self.last_batch
            else:
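Combined with the new ScheduleBatch.is_prefill_only field (set when every request in the batch has max_new_tokens == 0), prefill-only batches now finish after their extend pass instead of being merged into the running decode batch. A hedged sketch of the kind of request that takes this path; GenerateReqInput and its fields are taken from the tokenizer-manager hunks below:

# Sketch only: a scoring-style request that generates no tokens. A batch made
# entirely of such requests has max_new_tokens == 0 everywhere, so
# is_prefill_only becomes True and the merge above is skipped.
from sglang.srt.managers.io_struct import GenerateReqInput

score_like_request = GenerateReqInput(
    text="The capital of France is Paris.",
    return_logprob=True,
    token_ids_logprob=[101, 102],          # hypothetical label token ids
    stream=False,
    sampling_params={"max_new_tokens": 0},
)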
@@ -1634,7 +1635,6 @@ class Scheduler(
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
-            self.server_args.enable_custom_logit_processor,
            chunked_req=self.chunked_req,
        )
        if self.enable_hierarchical_cache:
@@ -2031,7 +2031,6 @@ class Scheduler(
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
-            self.server_args.enable_custom_logit_processor,
        )
        idle_batch.prepare_for_idle()
        return idle_batch
@@ -699,7 +699,7 @@ class TokenizerManager:
            # Process all requests
            tokenized_objs = []
            for i, req in enumerate(requests):
-                self._validate_token_len(obj[i], input_ids_list[i])
+                self._validate_one_request(obj[i], input_ids_list[i])
                tokenized_objs.append(
                    self._create_tokenized_object(
                        req, req.text, input_ids_list[i], None, None
@@ -1529,6 +1529,7 @@ class TokenizerManager:
                "id": rid,
                "finish_reason": recv_obj.finished_reasons[i],
                "prompt_tokens": recv_obj.prompt_tokens[i],
+                "weight_version": self.server_args.weight_version,
            }

            if getattr(state.obj, "return_logprob", False):
@@ -1892,6 +1893,13 @@ class TokenizerManager:
                        f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
                    )

+        batch_request = GenerateReqInput(
+            token_ids_logprob=label_token_ids,
+            return_logprob=True,
+            stream=False,
+            sampling_params={"max_new_tokens": 0},
+        )
+
        # Handle string or tokenized query/items
        if isinstance(query, str) and (
            isinstance(items, str)
@@ -1903,13 +1911,9 @@ class TokenizerManager:
                prompts = [f"{item}{query}" for item in items_list]
            else:
                prompts = [f"{query}{item}" for item in items_list]
-            batch_request = GenerateReqInput(
-                text=prompts,
-                return_logprob=True,
-                token_ids_logprob=label_token_ids,
-                stream=False,
-                sampling_params={"max_new_tokens": 1},
-            )
+
+            batch_request.text = prompts
+
        elif (
            isinstance(query, list)
            and isinstance(items, list)
@@ -1921,13 +1925,8 @@ class TokenizerManager:
                input_ids_list = [item + query for item in items]
            else:
                input_ids_list = [query + item for item in items]
-            batch_request = GenerateReqInput(
-                input_ids=input_ids_list,
-                return_logprob=True,
-                token_ids_logprob=label_token_ids,
-                stream=False,
-                sampling_params={"max_new_tokens": 1},
-            )
+
+            batch_request.input_ids = input_ids_list
        else:
            raise ValueError(
                "Invalid combination of query/items types for score_request."
@@ -1939,9 +1938,20 @@ class TokenizerManager:
        for result in results:
            # Get logprobs for each token
            logprobs = {}
-            for logprob, token_id, _ in result["meta_info"].get(
-                "output_token_ids_logprobs", []
-            )[0]:
+
+            # For scoring requests, we read from output_token_ids_logprobs since we want
+            # the logprobs for specific tokens mentioned in the label_token_ids at
+            # the next position after the last token in the prompt
+            output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
+
+            # Throw an error here if output_logprobs is None
+            if output_logprobs is None:
+                raise RuntimeError(
+                    f"output_logprobs is None for request {result['meta_info'].get('id', '<unknown>')}. "
+                    "This usually indicates a problem with the scoring request or the backend output."
+                )
+
+            for logprob, token_id, _ in output_logprobs[0]:
                if token_id in label_token_ids:
                    logprobs[token_id] = logprob

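Taken together, the score path now builds a single GenerateReqInput up front with max_new_tokens=0, fills in either text or input_ids depending on the query/items types, and reads the label-token logprobs from output_token_ids_logprobs at the first output position. A simplified post-processing sketch mirroring the loop above (the softmax normalization is an assumption for illustration, not something this diff shows):

import math

def scores_from_results(results, label_token_ids):
    # Mirrors the loop above: collect the logprob of each label token at the
    # first output position, then normalize over the label set.
    all_scores = []
    for result in results:
        output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
        if output_logprobs is None:
            raise RuntimeError("output_logprobs is None for a scoring request")

        logprobs = {}
        for logprob, token_id, _ in output_logprobs[0]:
            if token_id in label_token_ids:
                logprobs[token_id] = logprob

        # Assumed normalization step, for illustration only.
        exps = {t: math.exp(lp) for t, lp in logprobs.items()}
        total = sum(exps.values()) or 1.0
        all_scores.append([exps.get(t, 0.0) / total for t in label_token_ids])
    return all_scores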
@@ -20,7 +20,6 @@ Page-aligned memory pool.
 """

 import abc
-import weakref
 from typing import TYPE_CHECKING

 import torch
@@ -81,9 +80,6 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
        if self.free_group:
            self.free(torch.cat(self.free_group))

-    def estimated_num_new_pages(self, bs, extend_num_tokens):
-        return bs * ((extend_num_tokens + self.page_size - 1) // self.page_size)
-
    def merge_and_sort_free(self):
        if len(self.release_pages) > 0:
            self.free_pages = torch.cat((self.free_pages, self.release_pages))
@@ -149,6 +145,7 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
    def alloc(self, need_size: int):
        if self.need_sort and need_size > len(self.free_pages):
            self.merge_and_sort_free()
+
        if need_size > len(self.free_pages):
            return None

@@ -437,9 +434,13 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
        device: str,
        kvcache: KVCache,
        need_sort: bool,
+        max_num_extend_tokens: int,
    ):
        super().__init__(size, page_size, dtype, device, kvcache, need_sort)
        self.num_pages = size // page_size
+        self.max_num_extend_tokens_next_power_of_2 = next_power_of_2(
+            max_num_extend_tokens
+        )
        self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
        self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
        self.clear()
@@ -480,7 +481,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
        )

        bs = len(prefix_lens)
-        if self.need_sort and self.estimated_num_new_pages(bs, extend_num_tokens) > len(
+        if self.need_sort and extend_num_tokens // self.page_size + bs + 1 > len(
            self.free_pages
        ):
            self.merge_and_sort_free()
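The removed per-request estimate is replaced by a single batch-level worst case: full pages for all extend tokens, one boundary page per request, plus one. It only decides when merge_and_sort_free() has to run, so over-counting is harmless. A quick numeric check with assumed values:

# Illustrative values only: page_size = 64, 3 requests extending 10, 70, 1 tokens.
page_size = 64
extend_num_tokens = 10 + 70 + 1   # 81
bs = 3

# Upper bound on new pages this extend can require; compare with the condition
# `extend_num_tokens // self.page_size + bs + 1 > len(self.free_pages)` above.
needed_pages_bound = extend_num_tokens // page_size + bs + 1   # 1 + 3 + 1 = 5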
@@ -497,7 +498,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
            self.ret_values,
            next_power_of_2(bs),
            self.page_size,
-            next_power_of_2(extend_num_tokens),
+            self.max_num_extend_tokens_next_power_of_2,
        )

        if self.debug_mode:
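Passing next_power_of_2(max_num_extend_tokens), computed once in __init__, keeps this kernel argument constant across calls instead of varying with each batch's extend_num_tokens (likely to avoid re-specializing the kernel per batch size, though the diff does not say so). A stand-in with the usual next_power_of_2 semantics to show the values involved:

def next_power_of_2(n: int) -> int:
    # Stand-in for illustration: smallest power of two >= n (and >= 1).
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

# Old argument tracked each batch: next_power_of_2(extend_num_tokens)
[next_power_of_2(n) for n in (81, 1000, 4096)]    # [128, 1024, 4096]

# New argument is fixed at construction time from the configured maximum:
max_num_extend_tokens = 8192                       # hypothetical server limit
next_power_of_2(max_num_extend_tokens)             # 8192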
@@ -522,9 +523,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
        )

        bs = len(seq_lens)
-        if self.need_sort and self.estimated_num_new_pages(bs, 1) > len(
-            self.free_pages
-        ):
+        if self.need_sort and bs > len(self.free_pages):
            self.merge_and_sort_free()

        out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
@@ -578,151 +577,3 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):

    def load_cpu_copy(self, kv_cache_cpu, indices):
        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
-
-
-def alloc_extend_kernel_ascend(
-    prefix_lens,
-    seq_lens,
-    last_loc,
-    free_pages,
-    out_indices,
-    page_size,
-    device,
-):
-    extend_lens = seq_lens - prefix_lens
-    end_pos = torch.cumsum(extend_lens, 0)
-    start_pos = end_pos - extend_lens
-    num_new_pages = (seq_lens + page_size - 1) // page_size - (
-        prefix_lens + page_size - 1
-    ) // page_size
-    num_full_new_pages = (seq_lens) // page_size - (
-        prefix_lens + page_size - 1
-    ) // page_size
-    need_page = num_new_pages - num_full_new_pages
-    end_new_pages = torch.cumsum(num_new_pages, 0)
-    start_new_pages = end_new_pages - num_new_pages
-    pos_in_page = torch.arange(page_size, device=device, dtype=torch.int32)
-    for i in range(len(prefix_lens)):
-        num1 = (
-            min(
-                seq_lens[i],
-                (prefix_lens[i] + page_size - 1) // page_size * page_size,
-            )
-            - prefix_lens[i]
-        )
-        if num1:
-            out_indices[start_pos[i] : start_pos[i] + num1] = (
-                last_loc[i] + 1 + pos_in_page[:num1].view(-1)
-            )
-
-        num2 = (
-            seq_lens[i] // page_size - (prefix_lens[i] + page_size - 1) // page_size
-        ) * page_size
-        if num2:
-            pages = (
-                free_pages[start_new_pages[i] : end_new_pages[i] - need_page[i]]
-                * page_size
-            )
-            out_indices[start_pos[i] + num1 : start_pos[i] + num1 + num2] = (
-                pages.view(-1, 1) + pos_in_page.view(1, -1)
-            ).view(-1)
-
-        num3 = seq_lens[i] - seq_lens[i] // page_size * page_size
-        if num3:
-            out_indices[end_pos[i] - num3 : end_pos[i]] = (
-                free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3]
-            ).view(-1)
-
-
-class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
-
-    def __init__(
-        self,
-        size: int,
-        page_size: int,
-        dtype: torch.dtype,
-        device: str,
-        kvcache: KVCache,
-        need_sort: bool,
-    ):
-        super().__init__(size, page_size, dtype, device, kvcache, need_sort)
-
-    def alloc_extend(
-        self,
-        prefix_lens: torch.Tensor,
-        seq_lens: torch.Tensor,
-        last_loc: torch.Tensor,
-        extend_num_tokens: int,
-    ):
-        if self.debug_mode:
-            assert torch.all(
-                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
-            )
-
-        estimated_num_new_pages = (
-            (
-                (seq_lens + self.page_size - 1) // self.page_size
-                - (prefix_lens + self.page_size - 1) // self.page_size
-            )
-            .sum()
-            .item()
-        )
-        if self.need_sort and estimated_num_new_pages > len(self.free_pages):
-            self.merge_and_sort_free()
-
-        if estimated_num_new_pages > len(self.free_pages):
-            return None
-
-        out_indices = torch.empty(
-            (extend_num_tokens,), dtype=torch.int32, device=self.device
-        )
-
-        alloc_extend_kernel_ascend(
-            prefix_lens,
-            seq_lens,
-            last_loc,
-            self.free_pages,
-            out_indices,
-            self.page_size,
-            self.device,
-        )
-
-        if self.debug_mode:
-            assert len(torch.unique(out_indices)) == len(out_indices)
-
-        self.free_pages = self.free_pages[estimated_num_new_pages:]
-        return out_indices
-
-    def alloc_decode(
-        self,
-        seq_lens: torch.Tensor,
-        last_loc: torch.Tensor,
-    ):
-        if self.debug_mode:
-            assert torch.all(
-                (last_loc + 2) % self.page_size == seq_lens % self.page_size
-            )
-
-        need_new_pages = (seq_lens % self.page_size == 1).int()
-        num_new_pages = need_new_pages.sum().item()
-
-        if num_new_pages > len(self.free_pages):
-            self.merge_and_sort_free()
-
-        if num_new_pages > len(self.free_pages):
-            return None
-
-        end_new_pages = torch.cumsum(need_new_pages, 0)
-        start_new_pages = end_new_pages - need_new_pages
-        if num_new_pages == 0:
-            out_indices = last_loc + 1
-        else:
-            out_indices = (last_loc + 1) * (1 - need_new_pages) + self.free_pages[
-                start_new_pages
-            ] * self.page_size * need_new_pages
-
-        if self.debug_mode:
-            assert len(torch.unique(out_indices)) == len(out_indices)
-
-        self.free_pages = self.free_pages[num_new_pages:]
-        return out_indices.int()
@@ -0,0 +1,158 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+
+from sglang.srt.mem_cache.allocator import PagedTokenToKVPoolAllocator
+
+if TYPE_CHECKING:
+    from sglang.srt.mem_cache.memory_pool import KVCache
+
+
+def alloc_extend_kernel_ascend(
+    prefix_lens,
+    seq_lens,
+    last_loc,
+    free_pages,
+    out_indices,
+    page_size,
+    device,
+):
+    extend_lens = seq_lens - prefix_lens
+    end_pos = torch.cumsum(extend_lens, 0)
+    start_pos = end_pos - extend_lens
+    num_new_pages = (seq_lens + page_size - 1) // page_size - (
+        prefix_lens + page_size - 1
+    ) // page_size
+    num_full_new_pages = (seq_lens) // page_size - (
+        prefix_lens + page_size - 1
+    ) // page_size
+    need_page = num_new_pages - num_full_new_pages
+    end_new_pages = torch.cumsum(num_new_pages, 0)
+    start_new_pages = end_new_pages - num_new_pages
+    pos_in_page = torch.arange(page_size, device=device, dtype=torch.int32)
+    for i in range(len(prefix_lens)):
+        num1 = (
+            min(
+                seq_lens[i],
+                (prefix_lens[i] + page_size - 1) // page_size * page_size,
+            )
+            - prefix_lens[i]
+        )
+        if num1:
+            out_indices[start_pos[i] : start_pos[i] + num1] = (
+                last_loc[i] + 1 + pos_in_page[:num1].view(-1)
+            )
+
+        num2 = (
+            seq_lens[i] // page_size - (prefix_lens[i] + page_size - 1) // page_size
+        ) * page_size
+        if num2:
+            pages = (
+                free_pages[start_new_pages[i] : end_new_pages[i] - need_page[i]]
+                * page_size
+            )
+            out_indices[start_pos[i] + num1 : start_pos[i] + num1 + num2] = (
+                pages.view(-1, 1) + pos_in_page.view(1, -1)
+            ).view(-1)
+
+        num3 = seq_lens[i] - seq_lens[i] // page_size * page_size
+        if num3:
+            out_indices[end_pos[i] - num3 : end_pos[i]] = (
+                free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3]
+            ).view(-1)
+
+
+class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
+
+    def __init__(
+        self,
+        size: int,
+        page_size: int,
+        dtype: torch.dtype,
+        device: str,
+        kvcache: KVCache,
+        need_sort: bool,
+    ):
+        super().__init__(size, page_size, dtype, device, kvcache, need_sort, 1)
+
+    def alloc_extend(
+        self,
+        prefix_lens: torch.Tensor,
+        seq_lens: torch.Tensor,
+        last_loc: torch.Tensor,
+        extend_num_tokens: int,
+    ):
+        if self.debug_mode:
+            assert torch.all(
+                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
+            )
+
+        num_new_pages = (
+            (
+                (seq_lens + self.page_size - 1) // self.page_size
+                - (prefix_lens + self.page_size - 1) // self.page_size
+            )
+            .sum()
+            .item()
+        )
+        if self.need_sort and num_new_pages > len(self.free_pages):
+            self.merge_and_sort_free()
+
+        if num_new_pages > len(self.free_pages):
+            return None
+
+        out_indices = torch.empty(
+            (extend_num_tokens,), dtype=torch.int32, device=self.device
+        )
+
+        alloc_extend_kernel_ascend(
+            prefix_lens,
+            seq_lens,
+            last_loc,
+            self.free_pages,
+            out_indices,
+            self.page_size,
+            self.device,
+        )
+
+        if self.debug_mode:
+            assert len(torch.unique(out_indices)) == len(out_indices)
+
+        self.free_pages = self.free_pages[num_new_pages:]
+        return out_indices
+
+    def alloc_decode(
+        self,
+        seq_lens: torch.Tensor,
+        last_loc: torch.Tensor,
+    ):
+        if self.debug_mode:
+            assert torch.all(
+                (last_loc + 2) % self.page_size == seq_lens % self.page_size
+            )
+
+        need_new_pages = (seq_lens % self.page_size == 1).int()
+        num_new_pages = need_new_pages.sum().item()
+
+        if num_new_pages > len(self.free_pages):
+            self.merge_and_sort_free()
+
+        if num_new_pages > len(self.free_pages):
+            return None
+
+        end_new_pages = torch.cumsum(need_new_pages, 0)
+        start_new_pages = end_new_pages - need_new_pages
+        if num_new_pages == 0:
+            out_indices = last_loc + 1
+        else:
+            out_indices = (last_loc + 1) * (1 - need_new_pages) + self.free_pages[
+                start_new_pages
+            ] * self.page_size * need_new_pages
+
+        if self.debug_mode:
+            assert len(torch.unique(out_indices)) == len(out_indices)
+
+        self.free_pages = self.free_pages[num_new_pages:]
+        return out_indices.int()
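For reference, the Ascend extend kernel splits each request's new tokens into three segments: the remainder of the last partially filled prefix page (num1), whole new pages (num2), and a trailing partial page (num3). A worked example of that split with assumed values:

# Assumed values for illustration: page_size = 4, one request with
# prefix_len = 6 tokens already cached and seq_len = 15, i.e. 9 new tokens.
page_size = 4
prefix_len, seq_len = 6, 15

# num1: tokens that still fit in the partially used last prefix page.
num1 = min(seq_len, (prefix_len + page_size - 1) // page_size * page_size) - prefix_len
# -> min(15, 8) - 6 = 2

# num2: tokens that fill completely new, full pages.
num2 = (seq_len // page_size - (prefix_len + page_size - 1) // page_size) * page_size
# -> (3 - 2) * 4 = 4

# num3: tokens spilling into a final, partially filled new page.
num3 = seq_len - seq_len // page_size * page_size
# -> 15 - 12 = 3

assert num1 + num2 + num3 == seq_len - prefix_len  # 2 + 4 + 3 == 9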
@@ -2,7 +2,7 @@ from __future__ import annotations

 """Cache for chunked prefill, used when RadixCache is disabled."""

-from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Optional

 import torch