sglang 0.5.0rc1__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -1
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/disaggregation/decode.py +0 -1
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/entrypoints/http_server.py +64 -0
- sglang/srt/entrypoints/openai/protocol.py +2 -0
- sglang/srt/entrypoints/openai/serving_chat.py +1 -0
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/layers/attention/flashinfer_backend.py +3 -0
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
- sglang/srt/layers/attention/triton_backend.py +24 -27
- sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +10 -3
- sglang/srt/layers/communicator.py +7 -7
- sglang/srt/layers/dp_attention.py +118 -27
- sglang/srt/layers/logits_processor.py +12 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +5 -32
- sglang/srt/layers/quantization/awq.py +15 -16
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
- sglang/srt/layers/quantization/gptq.py +12 -17
- sglang/srt/layers/quantization/marlin_utils.py +15 -5
- sglang/srt/layers/quantization/modelopt_quant.py +52 -30
- sglang/srt/layers/quantization/mxfp4.py +16 -2
- sglang/srt/layers/quantization/utils.py +52 -2
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/layers.py +6 -2
- sglang/srt/managers/cache_controller.py +4 -1
- sglang/srt/managers/io_struct.py +14 -0
- sglang/srt/managers/schedule_batch.py +18 -39
- sglang/srt/managers/scheduler.py +3 -4
- sglang/srt/managers/tokenizer_manager.py +28 -18
- sglang/srt/mem_cache/allocator.py +8 -157
- sglang/srt/mem_cache/allocator_ascend.py +158 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +8 -21
- sglang/srt/model_executor/forward_batch_info.py +8 -10
- sglang/srt/model_executor/model_runner.py +57 -53
- sglang/srt/models/deepseek_nextn.py +2 -1
- sglang/srt/models/deepseek_v2.py +5 -3
- sglang/srt/models/glm4_moe.py +2 -2
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/gpt_oss.py +7 -2
- sglang/srt/models/llama.py +10 -2
- sglang/srt/models/llama4.py +18 -5
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_moe.py +20 -5
- sglang/srt/models/qwen3_classification.py +78 -0
- sglang/srt/models/qwen3_moe.py +18 -5
- sglang/srt/models/step3_vl.py +6 -2
- sglang/srt/operations.py +17 -2
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +33 -7
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/two_batch_overlap.py +4 -8
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +5 -5
- {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +75 -63
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py CHANGED
@@ -84,7 +84,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "device",
     "disable_chunked_prefix_cache",
     "disable_radix_cache",
-    "enable_dp_attention",
     "enable_two_batch_overlap",
     "tbo_token_distribution_threshold",
     "enable_dp_lm_head",
@@ -113,6 +112,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "enable_multimodal",
     "enable_symm_mem",
     "quantization",
+    "enable_custom_logit_processor",
 ]
 
 # Put some global args for easy access
@@ -909,12 +909,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None
 
-    # Enable custom logit processor
-    enable_custom_logit_processor: bool = False
-
     # Whether to return hidden states
     return_hidden_states: bool = False
 
+    # Whether this batch is prefill-only (no token generation needed)
+    is_prefill_only: bool = False
+
     # hicache pointer for synchronizing data loading from CPU to GPU
     hicache_consumer_index: int = 0
 
@@ -928,7 +928,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         model_config: ModelConfig,
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
-        enable_custom_logit_processor: bool,
         chunked_req: Optional[Req] = None,
     ):
         return_logprob = any(req.return_logprob for req in reqs)
@@ -955,8 +954,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
-            enable_custom_logit_processor=enable_custom_logit_processor,
             return_hidden_states=any(req.return_hidden_states for req in reqs),
+            is_prefill_only=all(
+                req.sampling_params.max_new_tokens == 0 for req in reqs
+            ),
             chunked_req=chunked_req,
         )
 
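Note: the new `is_prefill_only` flag is derived entirely from the requests' sampling parameters, as the `init_new` hunk above shows. A minimal sketch of the same check (not part of the diff, using hypothetical stand-ins for sglang's `Req` and `SamplingParams`):

from dataclasses import dataclass
from typing import List


@dataclass
class SamplingParams:  # hypothetical stand-in for sglang's SamplingParams
    max_new_tokens: int


@dataclass
class Req:  # hypothetical stand-in for sglang's Req
    sampling_params: SamplingParams


def batch_is_prefill_only(reqs: List[Req]) -> bool:
    # Mirrors the init_new() change: a batch is prefill-only when every request
    # asks for zero new tokens (e.g. logprob-only scoring requests).
    return all(req.sampling_params.max_new_tokens == 0 for req in reqs)


assert batch_is_prefill_only([Req(SamplingParams(0)), Req(SamplingParams(0))])
assert not batch_is_prefill_only([Req(SamplingParams(0)), Req(SamplingParams(16))])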
@@ -1009,6 +1010,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         extend_num_tokens: int,
         backup_state: bool = False,
     ):
+        # Over estimate the number of tokens: assume each request needs a new page.
         num_tokens = (
             extend_num_tokens
             + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
@@ -1041,8 +1043,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         last_loc: torch.Tensor,
         backup_state: bool = False,
     ):
+        # Over estimate the number of tokens: assume each request needs a new page.
         num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size
-
         self._evict_tree_cache_if_needed(num_tokens)
 
         if backup_state:
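A small worked example (not from the diff) of the over-estimate described in the new comments, assuming a page size of 16 tokens:

# Worked example (not from the diff), assuming page_size = 16.
page_size = 16
extend_num_tokens = 100  # new tokens to prefill across the batch
num_requests = 3         # len(seq_lens)

# Each request is assumed to open one extra (possibly partial) page.
num_tokens = extend_num_tokens + num_requests * page_size
assert num_tokens == 148  # budget passed to _evict_tree_cache_if_needed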
@@ -1721,38 +1723,18 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         extend_prefix_lens = self.prefix_lens
         extend_logprob_start_lens = self.extend_logprob_start_lens
 
-        if self.forward_mode.is_decode_or_idle():
-            attention_backend_str = global_server_args_dict["decode_attention_backend"]
-        else:
-            attention_backend_str = global_server_args_dict["prefill_attention_backend"]
-        # Create seq_lens_cpu when needed
-        if (
-            attention_backend_str
-            in [
-                "fa3",
-                "flashinfer",
-                "flashmla",
-                "cutlass_mla",
-                "ascend",
-                "trtllm_mha",
-                "aiter",
-            ]
-            or global_server_args_dict["enable_two_batch_overlap"]
-        ):
-            seq_lens_cpu = (
-                seq_lens_cpu_cache
-                if seq_lens_cpu_cache is not None
-                else self.seq_lens.cpu()
-            )
-        else:
-            seq_lens_cpu = None
-
         if self.sampling_info:
             if self.has_grammar:
                 self.sampling_info.grammars = [req.grammar for req in self.reqs]
             else:
                 self.sampling_info.grammars = None
 
+        seq_lens_cpu = (
+            seq_lens_cpu_cache
+            if seq_lens_cpu_cache is not None
+            else self.seq_lens.cpu()
+        )
+
         global bid
         bid += 1
         return ModelWorkerBatch(
@@ -1815,18 +1797,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
             spec_algorithm=self.spec_algorithm,
-            enable_custom_logit_processor=self.enable_custom_logit_processor,
             global_num_tokens=self.global_num_tokens,
             global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
             can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
             is_extend_in_batch=self.is_extend_in_batch,
+            is_prefill_only=self.is_prefill_only,
         )
 
-    def _evict_tree_cache_if_needed(
-        self,
-        num_tokens: int,
-    ) -> None:
-        if isinstance(self.tree_cache, SWAChunkCache):
+    def _evict_tree_cache_if_needed(self, num_tokens: int):
+        if isinstance(self.tree_cache, (SWAChunkCache, ChunkCache)):
             return
 
         if self.is_hybrid:
sglang/srt/managers/scheduler.py CHANGED
@@ -1466,8 +1466,9 @@ class Scheduler(
             if self.last_batch.batch_size() < last_bs:
                 self.running_batch.batch_is_full = False
 
-            # Merge the new batch into the running batch
-            if not self.last_batch.is_empty():
+            # Merge the new batch into the running batch.
+            # For prefill-only batch, we can avoid going through decoding step.
+            if not self.last_batch.is_empty() and not self.last_batch.is_prefill_only:
                 if self.running_batch.is_empty():
                     self.running_batch = self.last_batch
                 else:
@@ -1634,7 +1635,6 @@ class Scheduler(
             self.model_config,
             self.enable_overlap,
             self.spec_algorithm,
-            self.server_args.enable_custom_logit_processor,
             chunked_req=self.chunked_req,
         )
         if self.enable_hierarchical_cache:
@@ -2031,7 +2031,6 @@ class Scheduler(
             self.model_config,
             self.enable_overlap,
             self.spec_algorithm,
-            self.server_args.enable_custom_logit_processor,
         )
         idle_batch.prepare_for_idle()
         return idle_batch
sglang/srt/managers/tokenizer_manager.py CHANGED
@@ -699,7 +699,7 @@ class TokenizerManager:
         # Process all requests
         tokenized_objs = []
         for i, req in enumerate(requests):
-            self.
+            self._validate_one_request(obj[i], input_ids_list[i])
            tokenized_objs.append(
                self._create_tokenized_object(
                    req, req.text, input_ids_list[i], None, None
@@ -1529,6 +1529,7 @@ class TokenizerManager:
                 "id": rid,
                 "finish_reason": recv_obj.finished_reasons[i],
                 "prompt_tokens": recv_obj.prompt_tokens[i],
+                "weight_version": self.server_args.weight_version,
             }
 
             if getattr(state.obj, "return_logprob", False):
@@ -1892,6 +1893,13 @@ class TokenizerManager:
                         f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
                     )
 
+        batch_request = GenerateReqInput(
+            token_ids_logprob=label_token_ids,
+            return_logprob=True,
+            stream=False,
+            sampling_params={"max_new_tokens": 0},
+        )
+
         # Handle string or tokenized query/items
         if isinstance(query, str) and (
             isinstance(items, str)
@@ -1903,13 +1911,9 @@ class TokenizerManager:
                 prompts = [f"{item}{query}" for item in items_list]
             else:
                 prompts = [f"{query}{item}" for item in items_list]
-            batch_request = GenerateReqInput(
-                text=prompts,
-                return_logprob=True,
-                token_ids_logprob=label_token_ids,
-                stream=False,
-                sampling_params={"max_new_tokens": 1},
-            )
+
+            batch_request.text = prompts
+
         elif (
             isinstance(query, list)
             and isinstance(items, list)
@@ -1921,13 +1925,8 @@ class TokenizerManager:
                 input_ids_list = [item + query for item in items]
             else:
                 input_ids_list = [query + item for item in items]
-            batch_request = GenerateReqInput(
-                input_ids=input_ids_list,
-                return_logprob=True,
-                token_ids_logprob=label_token_ids,
-                stream=False,
-                sampling_params={"max_new_tokens": 1},
-            )
+
+            batch_request.input_ids = input_ids_list
         else:
             raise ValueError(
                 "Invalid combination of query/items types for score_request."
@@ -1939,9 +1938,20 @@ class TokenizerManager:
         for result in results:
             # Get logprobs for each token
             logprobs = {}
-
-
-
+
+            # For scoring requests, we read from output_token_ids_logprobs since we want
+            # the logprobs for specific tokens mentioned in the label_token_ids at
+            # the next position after the last token in the prompt
+            output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
+
+            # Throw an error here if output_logprobs is None
+            if output_logprobs is None:
+                raise RuntimeError(
+                    f"output_logprobs is None for request {result['meta_info'].get('id', '<unknown>')}. "
+                    "This usually indicates a problem with the scoring request or the backend output."
+                )
+
+            for logprob, token_id, _ in output_logprobs[0]:
                 if token_id in label_token_ids:
                     logprobs[token_id] = logprob
 
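Taken together, the tokenizer_manager hunks turn scoring into a single prefill-only request (`max_new_tokens=0`) whose label logprobs are read back from `output_token_ids_logprobs`. A minimal sketch of that post-processing (not part of the diff; the `meta_info` layout is assumed from the hunks above):

from typing import Any, Dict, List


def extract_label_logprobs(result: Dict[str, Any], label_token_ids: List[int]) -> Dict[int, float]:
    # meta_info["output_token_ids_logprobs"][0] is assumed to hold
    # (logprob, token_id, text) entries for the position right after the prompt.
    output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
    if output_logprobs is None:
        raise RuntimeError("output_logprobs is None; the scoring request is likely misconfigured")
    logprobs = {}
    for logprob, token_id, _ in output_logprobs[0]:
        if token_id in label_token_ids:
            logprobs[token_id] = logprob
    return logprobs


# Example with a fabricated backend result for label tokens 10 and 11
fake_result = {"meta_info": {"output_token_ids_logprobs": [[(-0.1, 10, None), (-2.3, 11, None)]]}}
print(extract_label_logprobs(fake_result, [10, 11]))  # {10: -0.1, 11: -2.3}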
sglang/srt/mem_cache/allocator.py CHANGED
@@ -20,7 +20,6 @@ Page-aligned memory pool.
 """
 
 import abc
-import weakref
 from typing import TYPE_CHECKING
 
 import torch
@@ -81,9 +80,6 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
         if self.free_group:
             self.free(torch.cat(self.free_group))
 
-    def estimated_num_new_pages(self, bs, extend_num_tokens):
-        return bs * ((extend_num_tokens + self.page_size - 1) // self.page_size)
-
     def merge_and_sort_free(self):
         if len(self.release_pages) > 0:
             self.free_pages = torch.cat((self.free_pages, self.release_pages))
@@ -149,6 +145,7 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
     def alloc(self, need_size: int):
         if self.need_sort and need_size > len(self.free_pages):
             self.merge_and_sort_free()
+
         if need_size > len(self.free_pages):
             return None
 
@@ -437,9 +434,13 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         device: str,
         kvcache: KVCache,
         need_sort: bool,
+        max_num_extend_tokens: int,
     ):
         super().__init__(size, page_size, dtype, device, kvcache, need_sort)
         self.num_pages = size // page_size
+        self.max_num_extend_tokens_next_power_of_2 = next_power_of_2(
+            max_num_extend_tokens
+        )
         self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
         self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
         self.clear()
@@ -480,7 +481,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             )
 
         bs = len(prefix_lens)
-        if self.need_sort and self.estimated_num_new_pages(bs, extend_num_tokens) > len(
+        if self.need_sort and extend_num_tokens // self.page_size + bs + 1 > len(
             self.free_pages
         ):
             self.merge_and_sort_free()
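A small worked example (not from the diff) of the upper bound used in the new condition, assuming a page size of 16:

page_size = 16
bs = 3                   # len(prefix_lens), requests in the batch
extend_num_tokens = 100  # total new tokens across the batch

# Full pages for the new tokens, plus one possibly partial page per request,
# plus one spare page.
pages_needed = extend_num_tokens // page_size + bs + 1
assert pages_needed == 10  # merge_and_sort_free() runs if fewer than 10 pages are free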
@@ -497,7 +498,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             self.ret_values,
             next_power_of_2(bs),
             self.page_size,
-            next_power_of_2(extend_num_tokens),
+            self.max_num_extend_tokens_next_power_of_2,
         )
 
         if self.debug_mode:
@@ -522,9 +523,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             )
 
         bs = len(seq_lens)
-        if self.need_sort and self.estimated_num_new_pages(bs, 1) > len(
-            self.free_pages
-        ):
+        if self.need_sort and bs > len(self.free_pages):
             self.merge_and_sort_free()
 
         out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
@@ -578,151 +577,3 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
 
     def load_cpu_copy(self, kv_cache_cpu, indices):
         return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
-
-
-def alloc_extend_kernel_ascend(
-    prefix_lens,
-    seq_lens,
-    last_loc,
-    free_pages,
-    out_indices,
-    page_size,
-    device,
-):
-    extend_lens = seq_lens - prefix_lens
-    end_pos = torch.cumsum(extend_lens, 0)
-    start_pos = end_pos - extend_lens
-    num_new_pages = (seq_lens + page_size - 1) // page_size - (
-        prefix_lens + page_size - 1
-    ) // page_size
-    num_full_new_pages = (seq_lens) // page_size - (
-        prefix_lens + page_size - 1
-    ) // page_size
-    need_page = num_new_pages - num_full_new_pages
-    end_new_pages = torch.cumsum(num_new_pages, 0)
-    start_new_pages = end_new_pages - num_new_pages
-    pos_in_page = torch.arange(page_size, device=device, dtype=torch.int32)
-    for i in range(len(prefix_lens)):
-        num1 = (
-            min(
-                seq_lens[i],
-                (prefix_lens[i] + page_size - 1) // page_size * page_size,
-            )
-            - prefix_lens[i]
-        )
-        if num1:
-            out_indices[start_pos[i] : start_pos[i] + num1] = (
-                last_loc[i] + 1 + pos_in_page[:num1].view(-1)
-            )
-
-        num2 = (
-            seq_lens[i] // page_size - (prefix_lens[i] + page_size - 1) // page_size
-        ) * page_size
-        if num2:
-            pages = (
-                free_pages[start_new_pages[i] : end_new_pages[i] - need_page[i]]
-                * page_size
-            )
-            out_indices[start_pos[i] + num1 : start_pos[i] + num1 + num2] = (
-                pages.view(-1, 1) + pos_in_page.view(1, -1)
-            ).view(-1)
-
-        num3 = seq_lens[i] - seq_lens[i] // page_size * page_size
-        if num3:
-            out_indices[end_pos[i] - num3 : end_pos[i]] = (
-                free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3]
-            ).view(-1)
-
-
-class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
-
-    def __init__(
-        self,
-        size: int,
-        page_size: int,
-        dtype: torch.dtype,
-        device: str,
-        kvcache: KVCache,
-        need_sort: bool,
-    ):
-        super().__init__(size, page_size, dtype, device, kvcache, need_sort)
-
-    def alloc_extend(
-        self,
-        prefix_lens: torch.Tensor,
-        seq_lens: torch.Tensor,
-        last_loc: torch.Tensor,
-        extend_num_tokens: int,
-    ):
-        if self.debug_mode:
-            assert torch.all(
-                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
-            )
-
-        estimated_num_new_pages = (
-            (
-                (seq_lens + self.page_size - 1) // self.page_size
-                - (prefix_lens + self.page_size - 1) // self.page_size
-            )
-            .sum()
-            .item()
-        )
-        if self.need_sort and estimated_num_new_pages > len(self.free_pages):
-            self.merge_and_sort_free()
-
-        if estimated_num_new_pages > len(self.free_pages):
-            return None
-
-        out_indices = torch.empty(
-            (extend_num_tokens,), dtype=torch.int32, device=self.device
-        )
-
-        alloc_extend_kernel_ascend(
-            prefix_lens,
-            seq_lens,
-            last_loc,
-            self.free_pages,
-            out_indices,
-            self.page_size,
-            self.device,
-        )
-
-        if self.debug_mode:
-            assert len(torch.unique(out_indices)) == len(out_indices)
-
-        self.free_pages = self.free_pages[estimated_num_new_pages:]
-        return out_indices
-
-    def alloc_decode(
-        self,
-        seq_lens: torch.Tensor,
-        last_loc: torch.Tensor,
-    ):
-        if self.debug_mode:
-            assert torch.all(
-                (last_loc + 2) % self.page_size == seq_lens % self.page_size
-            )
-
-        need_new_pages = (seq_lens % self.page_size == 1).int()
-        num_new_pages = need_new_pages.sum().item()
-
-        if num_new_pages > len(self.free_pages):
-            self.merge_and_sort_free()
-
-        if num_new_pages > len(self.free_pages):
-            return None
-
-        end_new_pages = torch.cumsum(need_new_pages, 0)
-        start_new_pages = end_new_pages - need_new_pages
-        if num_new_pages == 0:
-            out_indices = last_loc + 1
-        else:
-            out_indices = (last_loc + 1) * (1 - need_new_pages) + self.free_pages[
-                start_new_pages
-            ] * self.page_size * need_new_pages
-
-        if self.debug_mode:
-            assert len(torch.unique(out_indices)) == len(out_indices)
-
-        self.free_pages = self.free_pages[num_new_pages:]
-        return out_indices.int()
sglang/srt/mem_cache/allocator_ascend.py ADDED
@@ -0,0 +1,158 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+
+from sglang.srt.mem_cache.allocator import PagedTokenToKVPoolAllocator
+
+if TYPE_CHECKING:
+    from sglang.srt.mem_cache.memory_pool import KVCache
+
+
+def alloc_extend_kernel_ascend(
+    prefix_lens,
+    seq_lens,
+    last_loc,
+    free_pages,
+    out_indices,
+    page_size,
+    device,
+):
+    extend_lens = seq_lens - prefix_lens
+    end_pos = torch.cumsum(extend_lens, 0)
+    start_pos = end_pos - extend_lens
+    num_new_pages = (seq_lens + page_size - 1) // page_size - (
+        prefix_lens + page_size - 1
+    ) // page_size
+    num_full_new_pages = (seq_lens) // page_size - (
+        prefix_lens + page_size - 1
+    ) // page_size
+    need_page = num_new_pages - num_full_new_pages
+    end_new_pages = torch.cumsum(num_new_pages, 0)
+    start_new_pages = end_new_pages - num_new_pages
+    pos_in_page = torch.arange(page_size, device=device, dtype=torch.int32)
+    for i in range(len(prefix_lens)):
+        num1 = (
+            min(
+                seq_lens[i],
+                (prefix_lens[i] + page_size - 1) // page_size * page_size,
+            )
+            - prefix_lens[i]
+        )
+        if num1:
+            out_indices[start_pos[i] : start_pos[i] + num1] = (
+                last_loc[i] + 1 + pos_in_page[:num1].view(-1)
+            )
+
+        num2 = (
+            seq_lens[i] // page_size - (prefix_lens[i] + page_size - 1) // page_size
+        ) * page_size
+        if num2:
+            pages = (
+                free_pages[start_new_pages[i] : end_new_pages[i] - need_page[i]]
+                * page_size
+            )
+            out_indices[start_pos[i] + num1 : start_pos[i] + num1 + num2] = (
+                pages.view(-1, 1) + pos_in_page.view(1, -1)
+            ).view(-1)
+
+        num3 = seq_lens[i] - seq_lens[i] // page_size * page_size
+        if num3:
+            out_indices[end_pos[i] - num3 : end_pos[i]] = (
+                free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3]
+            ).view(-1)
+
+
+class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
+
+    def __init__(
+        self,
+        size: int,
+        page_size: int,
+        dtype: torch.dtype,
+        device: str,
+        kvcache: KVCache,
+        need_sort: bool,
+    ):
+        super().__init__(size, page_size, dtype, device, kvcache, need_sort, 1)
+
+    def alloc_extend(
+        self,
+        prefix_lens: torch.Tensor,
+        seq_lens: torch.Tensor,
+        last_loc: torch.Tensor,
+        extend_num_tokens: int,
+    ):
+        if self.debug_mode:
+            assert torch.all(
+                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
+            )
+
+        num_new_pages = (
+            (
+                (seq_lens + self.page_size - 1) // self.page_size
+                - (prefix_lens + self.page_size - 1) // self.page_size
+            )
+            .sum()
+            .item()
+        )
+        if self.need_sort and num_new_pages > len(self.free_pages):
+            self.merge_and_sort_free()
+
+        if num_new_pages > len(self.free_pages):
+            return None
+
+        out_indices = torch.empty(
+            (extend_num_tokens,), dtype=torch.int32, device=self.device
+        )
+
+        alloc_extend_kernel_ascend(
+            prefix_lens,
+            seq_lens,
+            last_loc,
+            self.free_pages,
+            out_indices,
+            self.page_size,
+            self.device,
+        )
+
+        if self.debug_mode:
+            assert len(torch.unique(out_indices)) == len(out_indices)
+
+        self.free_pages = self.free_pages[num_new_pages:]
+        return out_indices
+
+    def alloc_decode(
+        self,
+        seq_lens: torch.Tensor,
+        last_loc: torch.Tensor,
+    ):
+        if self.debug_mode:
+            assert torch.all(
+                (last_loc + 2) % self.page_size == seq_lens % self.page_size
+            )
+
+        need_new_pages = (seq_lens % self.page_size == 1).int()
+        num_new_pages = need_new_pages.sum().item()
+
+        if num_new_pages > len(self.free_pages):
+            self.merge_and_sort_free()
+
+        if num_new_pages > len(self.free_pages):
+            return None
+
+        end_new_pages = torch.cumsum(need_new_pages, 0)
+        start_new_pages = end_new_pages - need_new_pages
+        if num_new_pages == 0:
+            out_indices = last_loc + 1
+        else:
+            out_indices = (last_loc + 1) * (1 - need_new_pages) + self.free_pages[
+                start_new_pages
+            ] * self.page_size * need_new_pages
+
+        if self.debug_mode:
+            assert len(torch.unique(out_indices)) == len(out_indices)
+
+        self.free_pages = self.free_pages[num_new_pages:]
+        return out_indices.int()
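A worked example (not from the diff) of the page accounting in `alloc_extend` above, assuming a page size of 4:

import torch

page_size = 4
prefix_lens = torch.tensor([5, 0])  # tokens already cached per request
seq_lens = torch.tensor([9, 3])     # total tokens after the extend

num_new_pages = (
    (seq_lens + page_size - 1) // page_size
    - (prefix_lens + page_size - 1) // page_size
).sum().item()

# Request 0: ceil(9/4) - ceil(5/4) = 3 - 2 = 1 new page
# Request 1: ceil(3/4) - ceil(0/4) = 1 - 0 = 1 new page
assert num_new_pages == 2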