sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -2
- sglang/bench_serving.py +224 -127
- sglang/compile_deep_gemm.py +3 -0
- sglang/launch_server.py +0 -14
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/falcon_h1.py +12 -58
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +68 -31
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +11 -43
- sglang/srt/disaggregation/decode.py +7 -18
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/nixl/conn.py +55 -23
- sglang/srt/disaggregation/prefill.py +17 -32
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/entrypoints/grpc_request_manager.py +10 -23
- sglang/srt/entrypoints/grpc_server.py +220 -80
- sglang/srt/entrypoints/http_server.py +49 -1
- sglang/srt/entrypoints/openai/protocol.py +159 -31
- sglang/srt/entrypoints/openai/serving_chat.py +13 -71
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +4 -0
- sglang/srt/function_call/function_call_parser.py +8 -6
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
- sglang/srt/layers/attention/attention_registry.py +31 -22
- sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
- sglang/srt/layers/attention/flashattention_backend.py +0 -1
- sglang/srt/layers/attention/flashinfer_backend.py +223 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/triton_backend.py +1 -1
- sglang/srt/layers/logits_processor.py +136 -6
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
- sglang/srt/layers/moe/ep_moe/layer.py +8 -286
- sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/utils.py +7 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/modelopt_quant.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/w4afp8.py +2 -16
- sglang/srt/lora/lora_manager.py +0 -8
- sglang/srt/managers/overlap_utils.py +18 -16
- sglang/srt/managers/schedule_batch.py +119 -90
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +213 -126
- sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
- sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
- sglang/srt/managers/tokenizer_manager.py +270 -53
- sglang/srt/managers/tp_worker.py +39 -28
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +162 -68
- sglang/srt/mem_cache/radix_cache.py +8 -3
- sglang/srt/mem_cache/swa_radix_cache.py +70 -14
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_executor/forward_batch_info.py +4 -18
- sglang/srt/model_executor/model_runner.py +55 -51
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +187 -6
- sglang/srt/model_loader/weight_utils.py +3 -0
- sglang/srt/models/falcon_h1.py +11 -9
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +11 -1
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/utils.py +5 -1
- sglang/srt/sampling/sampling_batch_info.py +11 -9
- sglang/srt/server_args.py +100 -33
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_utils.py +0 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils/common.py +18 -0
- sglang/srt/utils/hf_transformers_utils.py +2 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +40 -0
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +18 -2
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +63 -0
- sglang/test/test_utils.py +32 -11
- sglang/version.py +1 -1
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -182,6 +182,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             if speculative_algorithm.is_none()
             else server_args.speculative_num_draft_tokens
         )
+        # Initialize delimiter text for multi-item scoring (will be set after tokenizer is loaded)
+        self.multi_item_delimiter_text = None

        if self.model_config.is_multimodal:
            import_processors("sglang.srt.multimodal.processors")
@@ -223,6 +225,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
            self.processor = _processor
            self.tokenizer = get_tokenizer_from_processor(self.processor)
            os.environ["TOKENIZERS_PARALLELISM"] = "false"
+            self._initialize_multi_item_delimiter_text()
        else:
            self.mm_processor = self.processor = None

@@ -235,6 +238,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
            )
+            self._initialize_multi_item_delimiter_text()
        # Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal)
        if (
            server_args.enable_dynamic_batch_tokenizer
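The new `_initialize_multi_item_delimiter_text` hook (shown in full in the next hunk) decodes the configured delimiter token ID back to text once the tokenizer is available. A minimal sketch of that decode step, assuming a Hugging Face tokenizer; the model name and delimiter ID below are placeholders, not values used by sglang:

```python
# Sketch only: model name and delimiter ID are hypothetical.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
delimiter_token_id = 128009  # stand-in for server_args.multi_item_scoring_delimiter

# skip_special_tokens=False keeps the delimiter even if it is a special token,
# so its text form can later be spliced between the query and the items.
delimiter_text = tokenizer.decode([delimiter_token_id], skip_special_tokens=False)
print(repr(delimiter_text))
```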
@@ -1678,6 +1682,201 @@ class TokenizerManager(TokenizerCommunicatorMixin):
            if len(self.model_update_tmp) == self.server_args.dp_size:
                self.model_update_result.set_result(self.model_update_tmp)

+    def _initialize_multi_item_delimiter_text(self):
+        """Initialize multi-item delimiter text from token ID after tokenizer is loaded."""
+        if (
+            hasattr(self.server_args, "multi_item_scoring_delimiter")
+            and self.server_args.multi_item_scoring_delimiter is not None
+            and self.tokenizer is not None
+        ):
+            try:
+                self.multi_item_delimiter_text = self.tokenizer.decode(
+                    [self.server_args.multi_item_scoring_delimiter],
+                    skip_special_tokens=False,
+                )
+            except Exception as e:
+                logger.warning(
+                    f"Failed to decode delimiter token {self.server_args.multi_item_scoring_delimiter}: {e}"
+                )
+                self.multi_item_delimiter_text = None
+
+    def _build_multi_item_token_sequence(
+        self, query: List[int], items: List[List[int]], delimiter_token_id: int
+    ) -> List[int]:
+        """
+        Build a single token sequence for multi-item scoring.
+        Format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
+
+        Args:
+            query: Query token IDs
+            items: List of item token ID sequences
+            delimiter_token_id: Token ID to use as delimiter
+
+        Returns:
+            Combined token sequence
+        """
+        combined_sequence = query[:]  # Start with query
+
+        for item in items:
+            combined_sequence.append(delimiter_token_id)  # Add delimiter
+            combined_sequence.extend(item)  # Add item tokens
+
+        # Add final delimiter after the last item for logprob extraction
+        combined_sequence.append(delimiter_token_id)
+
+        return combined_sequence
+
+    def _extract_logprobs_for_tokens(
+        self, logprobs_data: List, label_token_ids: List[int]
+    ) -> Dict[int, float]:
+        """
+        Extract logprobs for specified token IDs from logprobs data.
+
+        Args:
+            logprobs_data: List of (logprob, token_id, text) tuples
+            label_token_ids: Token IDs to extract logprobs for
+
+        Returns:
+            Dictionary mapping token_id to logprob
+        """
+        logprobs = {}
+        if logprobs_data:
+            for logprob, token_id, _ in logprobs_data:
+                if token_id in label_token_ids:
+                    logprobs[token_id] = logprob
+        return logprobs
+
+    def _convert_logprobs_to_scores(
+        self,
+        logprobs: Dict[int, float],
+        label_token_ids: List[int],
+        apply_softmax: bool,
+    ) -> List[float]:
+        """
+        Convert logprobs dictionary to ordered score list.
+
+        Args:
+            logprobs: Dictionary mapping token_id to logprob
+            label_token_ids: Token IDs in desired order
+            apply_softmax: Whether to apply softmax normalization
+
+        Returns:
+            List of scores in the same order as label_token_ids
+        """
+        score_list = [
+            logprobs.get(token_id, float("-inf")) for token_id in label_token_ids
+        ]
+
+        if apply_softmax:
+            score_list = torch.softmax(torch.tensor(score_list), dim=0).tolist()
+        else:
+            # Convert logprobs to probabilities if not using softmax
+            score_list = [
+                math.exp(x) if x != float("-inf") else 0.0 for x in score_list
+            ]
+
+        return score_list
+
+    def _process_multi_item_scoring_results(
+        self,
+        results: Any,
+        items: List,
+        label_token_ids: List[int],
+        apply_softmax: bool,
+        batch_request=None,
+    ) -> List[List[float]]:
+        """
+        Process results from multi-item scoring request.
+        Extracts logprobs at delimiter positions from input_token_ids_logprobs.
+
+        Args:
+            results: Results from generate_request
+            items: List of items being scored
+            label_token_ids: Token IDs to extract scores for
+            apply_softmax: Whether to apply softmax normalization
+            batch_request: The original batch request containing input sequence
+
+        Returns:
+            List of score lists, one for each item
+        """
+        single_result = results[0] if isinstance(results, list) else results
+
+        # For multi-item scoring, logprobs are in input_token_ids_logprobs
+        input_logprobs = single_result["meta_info"].get("input_token_ids_logprobs", [])
+
+        if not input_logprobs:
+            raise RuntimeError(
+                f"input_token_ids_logprobs is empty for multi-item scoring request {single_result['meta_info'].get('id', '<unknown>')}. "
+                "This indicates token_ids_logprobs were not computed properly for Mutil Item Scoring."
+            )
+
+        scores = []
+        num_items = len(items) if isinstance(items, list) else 1
+
+        # Check if we have the expected number of logprobs
+        expected_logprobs_count = num_items + 1
+        if len(input_logprobs) != expected_logprobs_count:
+            raise RuntimeError(
+                f"Expected {expected_logprobs_count} input_token_ids_logprobs for multi-item scoring "
+                f"with {num_items} items, but got {len(input_logprobs)}. "
+                f"Request ID: {single_result['meta_info'].get('id', '<unknown>')}"
+            )
+
+        # Skip the first delimiter (between query and first item) and process remaining delimiter positions
+        # We want to exclude the first one since it represents the boundary between query and first item, not an item boundary
+        start_idx = 1 if len(input_logprobs) > 1 else 0
+
+        # Process logprobs for each item position (excluding first delimiter)
+        for item_idx in range(num_items):
+            logprob_idx = start_idx + item_idx
+            item_logprobs_data = input_logprobs[logprob_idx]
+            logprobs = self._extract_logprobs_for_tokens(
+                item_logprobs_data, label_token_ids
+            )
+            score_list = self._convert_logprobs_to_scores(
+                logprobs, label_token_ids, apply_softmax
+            )
+            scores.append(score_list)
+
+        return scores
+
+    def _process_single_item_scoring_results(
+        self, results: Any, label_token_ids: List[int], apply_softmax: bool
+    ) -> List[List[float]]:
+        """
+        Process results from single-item scoring request.
+        Single-item scoring results are stored in output_token_ids_logprobs.
+
+        Args:
+            results: Results from generate_request
+            label_token_ids: Token IDs to extract scores for
+            apply_softmax: Whether to apply softmax normalization
+
+        Returns:
+            List of score lists, one for each result
+        """
+        scores = []
+
+        for result in results:
+            # For single-item scoring, logprobs are in output_token_ids_logprobs
+            output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
+
+            if not output_logprobs or len(output_logprobs) == 0:
+                raise RuntimeError(
+                    f"output_logprobs is empty for request {result['meta_info'].get('id', '<unknown>')}."
+                )
+
+            # Extract logprobs for the first (and only) position
+            logprobs = self._extract_logprobs_for_tokens(
+                output_logprobs[0], label_token_ids
+            )
+            score_list = self._convert_logprobs_to_scores(
+                logprobs, label_token_ids, apply_softmax
+            )
+            scores.append(score_list)
+
+        return scores
+
    async def score_request(
        self,
        query: Optional[Union[str, List[int]]] = None,
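The score conversion added above is the only numeric step in the new helpers: missing labels default to negative infinity, and either a softmax over the label set or a plain exponential turns logprobs into probabilities. A standalone sketch of that conversion, independent of the TokenizerManager plumbing (the token IDs and logprob values are made up):

```python
import math
import torch

label_token_ids = [9454, 2753]        # e.g. placeholder IDs for "Yes" / "No"
logprobs = {9454: -0.3, 2753: -1.5}   # logprobs observed at one delimiter position

raw = [logprobs.get(t, float("-inf")) for t in label_token_ids]

# apply_softmax=True: normalize over the label set only
softmax_scores = torch.softmax(torch.tensor(raw), dim=0).tolist()

# apply_softmax=False: independent probabilities; missing labels become 0.0
plain_scores = [math.exp(x) if x != float("-inf") else 0.0 for x in raw]
```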
@@ -1688,7 +1887,29 @@ class TokenizerManager(TokenizerCommunicatorMixin):
        request: Optional[Any] = None,
    ) -> List[List[float]]:
        """
-
+        Score the probability of specified token IDs appearing after the given (query + item) pair.
+
+        This method supports two scoring approaches:
+        1. Single-Item scoring (default): Process each query+item pair independently
+        2. Multi-Item scoring: When multi_item_scoring_delimiter is set, combine query and
+           multiple items into a single sequence using delimiter for efficient processing.
+           Note: item_first parameter is ignored in multi-item scoring mode since it uses
+           a fixed format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
+
+        Multi-item scoring works with both text and pre-tokenized inputs:
+        - Text: query<delimiter_text>item1<delimiter_text>item2<delimiter_text>item3<delimiter_text>
+        - Tokens: query<delimiter_token_id>item1<delimiter_token_id>item2<delimiter_token_id>item3<delimiter_token_id>
+
+        Args:
+            query: The query text or pre-tokenized query token IDs
+            items: The item text(s) or pre-tokenized item token IDs
+            label_token_ids: List of token IDs to compute probabilities for
+            apply_softmax: Whether to normalize probabilities using softmax
+            item_first: If True, prepend items to query. Ignored for multi-item scoring.
+            request: Optional FastAPI request object
+
+        Returns:
+            List of lists containing probabilities for each item and each label token
        """
        if label_token_ids is None:
            raise ValueError("label_token_ids must be provided")
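As a usage illustration of the two modes described in the docstring, here is a hedged sketch of a call site. The `tokenizer_manager` instance, the prompt text, and the token IDs are placeholders; only parameters listed in the docstring are used, and the call must run inside an async context:

```python
# Hypothetical call site; `tokenizer_manager` is an already-constructed TokenizerManager.
scores = await tokenizer_manager.score_request(
    query="Is the following review positive?\n",
    items=["Great battery life.", "Screen broke after a week."],
    label_token_ids=[9454, 2753],  # placeholder IDs for "Yes" / "No"
    apply_softmax=True,
    item_first=False,  # ignored when multi_item_scoring_delimiter is configured
)
# Returns one [p_yes, p_no] list per item, in the order the items were passed.
```

Whether the server routes a call through the single-item or multi-item path depends only on whether `multi_item_scoring_delimiter` is set, as the next hunk shows.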
@@ -1701,9 +1922,17 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                    f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
                )

+        # Check if multi-item scoring is enabled by presence of delimiter
+        use_multi_item_scoring = (
+            self.server_args.multi_item_scoring_delimiter is not None
+            and self.multi_item_delimiter_text is not None
+        )
+
        batch_request = GenerateReqInput(
            token_ids_logprob=label_token_ids,
            return_logprob=True,
+            # Set logprob_start_len=0 for multi-item scoring since we want logprobs at all delimiter positions
+            logprob_start_len=0 if use_multi_item_scoring else -1,
            stream=False,
            sampling_params={"max_new_tokens": 0},
        )
@@ -1715,12 +1944,23 @@ class TokenizerManager(TokenizerCommunicatorMixin):
        ):
            # Both query and items are text
            items_list = [items] if isinstance(items, str) else items
-            if item_first:
-                prompts = [f"{item}{query}" for item in items_list]
-            else:
-                prompts = [f"{query}{item}" for item in items_list]

-
+            if use_multi_item_scoring:
+                # Multi-item scoring: create single prompt with delimiter text
+                # Always use format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
+                # (item_first is ignored for multi-item scoring)
+                delimiter = self.multi_item_delimiter_text
+                combined_items = delimiter.join(items_list)
+                # Add final delimiter after the last item for logprob extraction
+                single_prompt = f"{query}{delimiter}{combined_items}{delimiter}"
+                batch_request.text = [single_prompt]
+            else:
+                # Single-item scoring: create separate prompts for each item
+                if item_first:
+                    prompts = [f"{item}{query}" for item in items_list]
+                else:
+                    prompts = [f"{query}{item}" for item in items_list]
+                batch_request.text = prompts

        elif (
            isinstance(query, list)
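To make the text path concrete, a tiny sketch of the single prompt the multi-item branch builds; the delimiter string and inputs are invented for illustration, since the real delimiter text comes from decoding `server_args.multi_item_scoring_delimiter`:

```python
query = "Is the following review positive?\n"
items_list = ["Great battery life.", "Screen broke after a week."]
delimiter = "<sep>"  # illustrative only

single_prompt = f"{query}{delimiter}{delimiter.join(items_list)}{delimiter}"
# 'Is the following review positive?\n<sep>Great battery life.<sep>Screen broke after a week.<sep>'
# One delimiter precedes each item and one trailing delimiter marks where the
# last item's label logprobs are read.
```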
@@ -1729,61 +1969,38 @@ class TokenizerManager(TokenizerCommunicatorMixin):
            and isinstance(items[0], list)
        ):
            # Both query and items are token IDs
-            if
-
+            if use_multi_item_scoring:
+                # Multi-item scoring: concatenate with delimiter token ID
+                # Format: query<delimiter_token_id>item1<delimiter_token_id>item2<delimiter_token_id>item3<delimiter_token_id>
+                delimiter_token_id = self.server_args.multi_item_scoring_delimiter
+                combined_input_ids = self._build_multi_item_token_sequence(
+                    query, items, delimiter_token_id
+                )
+                batch_request.input_ids = [combined_input_ids]
            else:
-
-
-
+                # Single-item scoring: process each item separately
+                if item_first:
+                    input_ids_list = [item + query for item in items]
+                else:
+                    input_ids_list = [query + item for item in items]
+                batch_request.input_ids = input_ids_list
        else:
            raise ValueError(
                "Invalid combination of query/items types for score_request."
            )

        results = await self.generate_request(batch_request, request).__anext__()
-        scores = []
-
-        for result in results:
-            # Get logprobs for each token
-            logprobs = {}
-
-            # For scoring requests, we read from output_token_ids_logprobs since we want
-            # the logprobs for specific tokens mentioned in the label_token_ids at
-            # the next position after the last token in the prompt
-            output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
-
-            # Check if output_logprobs is properly populated
-            if (
-                output_logprobs is None
-                or not output_logprobs
-                or len(output_logprobs) == 0
-            ):
-                raise RuntimeError(
-                    f"output_logprobs is empty for request {result['meta_info'].get('id', '<unknown>')}. "
-                    "This indicates token_ids_logprobs were not computed properly for the scoring request."
-                )
-
-            for logprob, token_id, _ in output_logprobs[0]:
-                if token_id in label_token_ids:
-                    logprobs[token_id] = logprob

-
-
-
-
-
-
-
-
-
-
-            score_list = [
-                math.exp(x) if x != float("-inf") else 0.0 for x in score_list
-            ]
-
-            scores.append(score_list)
-
-        return scores
+        if use_multi_item_scoring:
+            # Multi-item scoring: extract scores from input_token_ids_logprobs
+            return self._process_multi_item_scoring_results(
+                results, items, label_token_ids, apply_softmax, batch_request
+            )
+        else:
+            # Single-item scoring: process each result separately
+            return self._process_single_item_scoring_results(
+                results, label_token_ids, apply_softmax
+            )

    async def watch_load_thread(self):
        # Only for dp_controller when dp_size > 1
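For the pre-tokenized path, the combined sequence and the number of logprob positions it yields can be sanity-checked in isolation. This sketch mirrors `_build_multi_item_token_sequence` outside the class, with made-up token IDs:

```python
# Mirror of the sequence layout used for multi-item scoring; IDs are arbitrary.
def build_multi_item_sequence(query, items, delimiter_token_id):
    seq = query[:]
    for item in items:
        seq.append(delimiter_token_id)
        seq.extend(item)
    seq.append(delimiter_token_id)  # trailing delimiter for the last item's logprobs
    return seq

query = [101, 102]
items = [[201], [202, 203], [204]]
seq = build_multi_item_sequence(query, items, delimiter_token_id=0)
assert seq == [101, 102, 0, 201, 0, 202, 203, 0, 204, 0]
# One delimiter per item plus a trailing one: len(items) + 1 delimiter positions,
# matching the expected_logprobs_count check in _process_multi_item_scoring_results.
assert seq.count(0) == len(items) + 1
```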
sglang/srt/managers/tp_worker.py CHANGED
@@ -15,14 +15,12 @@
 from __future__ import annotations

 import logging
-import
-from typing import TYPE_CHECKING, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional

 import torch

 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import get_pp_group, get_world_group
-from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     DestroyWeightsUpdateGroupReqInput,
     GetWeightsByNameReqInput,
@@ -36,13 +34,10 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
+from sglang.srt.managers.scheduler import GenerationBatchResult
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
-from sglang.srt.model_executor.forward_batch_info import (
-    ForwardBatch,
-    ForwardBatchOutput,
-    PPProxyTensors,
-)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
@@ -236,9 +231,8 @@ class TpModelWorker:
    def forward_batch_generation(
        self,
        model_worker_batch: ModelWorkerBatch,
-        launch_done: Optional[threading.Event] = None,
        is_verify: bool = False,
-    ) ->
+    ) -> GenerationBatchResult:
        # update the consumer index of hicache to the running batch
        self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)

@@ -256,32 +250,49 @@ class TpModelWorker:
            logits_output, can_run_cuda_graph = self.model_runner.forward(
                forward_batch, pp_proxy_tensors=pp_proxy_tensors
            )
-
-            launch_done.set()
-
-            skip_sample = is_verify or model_worker_batch.is_prefill_only
-            next_token_ids = None
-
-            if not skip_sample:
-                next_token_ids = self.model_runner.sample(logits_output, forward_batch)
-            elif model_worker_batch.return_logprob and not is_verify:
-                # NOTE: Compute logprobs without full sampling
-                self.model_runner.compute_logprobs_only(
-                    logits_output, model_worker_batch
-                )
-
-            return ForwardBatchOutput(
+            batch_result = GenerationBatchResult(
                logits_output=logits_output,
-                next_token_ids=next_token_ids,
                can_run_cuda_graph=can_run_cuda_graph,
            )
+
+            if is_verify:
+                # Skip sampling and return logits for target forward
+                return batch_result
+
+            if model_worker_batch.delay_sample_launch:
+                batch_result.delay_sample_launch = True
+                batch_result.forward_batch = forward_batch
+                return batch_result
+
+            if model_worker_batch.is_prefill_only:
+                # For prefill-only requests, create dummy token IDs on CPU
+                # The size should match the batch size (number of sequences), not total tokens
+                batch_result.next_token_ids = torch.zeros(
+                    len(model_worker_batch.seq_lens),
+                    dtype=torch.long,
+                    device=model_worker_batch.input_ids.device,
+                )
+                if (
+                    model_worker_batch.return_logprob
+                    and logits_output.next_token_logits is not None
+                ):
+                    # NOTE: Compute logprobs without full sampling
+                    self.model_runner.compute_logprobs_only(
+                        logits_output, model_worker_batch
+                    )
+            else:
+                batch_result.next_token_ids = self.model_runner.sample(
+                    logits_output, forward_batch
+                )
+
+            return batch_result
        else:
            pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
                forward_batch,
                pp_proxy_tensors=pp_proxy_tensors,
            )
-            return
-
+            return GenerationBatchResult(
+                pp_hidden_states_proxy_tensors=pp_proxy_tensors,
                can_run_cuda_graph=can_run_cuda_graph,
            )

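The refactor replaces the old `ForwardBatchOutput` return value with `GenerationBatchResult` (imported from `sglang.srt.managers.scheduler`) and turns the implicit sample/skip logic into explicit branches for verify, delayed sampling, and prefill-only batches. A much-simplified, hypothetical sketch of that result-carrier pattern, restricted to the fields visible in this hunk (the real class has more):

```python
# Hypothetical stand-in for GenerationBatchResult, for illustration only.
from dataclasses import dataclass
from typing import Any

@dataclass
class BatchResultSketch:
    logits_output: Any = None
    next_token_ids: Any = None
    can_run_cuda_graph: bool = False
    delay_sample_launch: bool = False
    forward_batch: Any = None
    pp_hidden_states_proxy_tensors: Any = None

def finish_forward(result: BatchResultSketch, is_verify: bool, delay_sample: bool) -> BatchResultSketch:
    # Mirrors the branch order above: verify returns logits only,
    # delayed sampling defers token selection, otherwise sample immediately.
    if is_verify:
        return result
    if delay_sample:
        result.delay_sample_launch = True
        return result
    result.next_token_ids = "<sampled here>"  # placeholder for model_runner.sample(...)
    return result
```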
sglang/srt/mem_cache/allocator.py CHANGED
@@ -274,10 +274,15 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
        self.full_to_swa_index_mapping[free_index] = 0

    def backup_state(self):
-
+        return [
+            self.full_attn_allocator.backup_state(),
+            self.swa_attn_allocator.backup_state(),
+        ]

    def restore_state(self, state):
-
+        assert len(state) == 2
+        self.full_attn_allocator.restore_state(state[0])
+        self.swa_attn_allocator.restore_state(state[1])

    def clear(self):
        self.swa_attn_allocator.clear()
sglang/srt/mem_cache/chunk_cache.py CHANGED
@@ -60,7 +60,7 @@ class ChunkCache(BasePrefixCache):
        ]

        # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
-        req.prefix_indices = kv_indices
+        req.prefix_indices = kv_indices.to(dtype=torch.int64, copy=True)

    def evict(self, num_tokens: int):
        pass
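The one-line change makes `prefix_indices` an int64 copy rather than an alias of `kv_indices`; the diff does not state the motivation, but the observable effect is that later changes to the index buffer no longer reach `prefix_indices`. A small sketch of the difference, using torch only with invented values:

```python
import torch

kv_indices = torch.tensor([3, 5, 9], dtype=torch.int32)

alias = kv_indices                                     # old behavior: same storage, int32
copy64 = kv_indices.to(dtype=torch.int64, copy=True)   # new behavior: independent int64 tensor

kv_indices[0] = 42          # later reuse/overwrite of the index buffer
print(alias[0].item())      # 42 - the alias sees the change
print(copy64[0].item())     # 3  - the copy keeps the original prefix indices
```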