sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. sglang/bench_one_batch.py +0 -2
  2. sglang/bench_serving.py +224 -127
  3. sglang/compile_deep_gemm.py +3 -0
  4. sglang/launch_server.py +0 -14
  5. sglang/srt/configs/__init__.py +2 -0
  6. sglang/srt/configs/falcon_h1.py +12 -58
  7. sglang/srt/configs/mamba_utils.py +117 -0
  8. sglang/srt/configs/model_config.py +68 -31
  9. sglang/srt/configs/nemotron_h.py +286 -0
  10. sglang/srt/configs/qwen3_next.py +11 -43
  11. sglang/srt/disaggregation/decode.py +7 -18
  12. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  13. sglang/srt/disaggregation/nixl/conn.py +55 -23
  14. sglang/srt/disaggregation/prefill.py +17 -32
  15. sglang/srt/entrypoints/engine.py +2 -2
  16. sglang/srt/entrypoints/grpc_request_manager.py +10 -23
  17. sglang/srt/entrypoints/grpc_server.py +220 -80
  18. sglang/srt/entrypoints/http_server.py +49 -1
  19. sglang/srt/entrypoints/openai/protocol.py +159 -31
  20. sglang/srt/entrypoints/openai/serving_chat.py +13 -71
  21. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  22. sglang/srt/environ.py +4 -0
  23. sglang/srt/function_call/function_call_parser.py +8 -6
  24. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  25. sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
  26. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
  27. sglang/srt/layers/attention/attention_registry.py +31 -22
  28. sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
  29. sglang/srt/layers/attention/flashattention_backend.py +0 -1
  30. sglang/srt/layers/attention/flashinfer_backend.py +223 -6
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
  32. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
  33. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  34. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
  35. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  36. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  37. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  38. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  39. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  40. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  41. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  42. sglang/srt/layers/attention/triton_backend.py +1 -1
  43. sglang/srt/layers/logits_processor.py +136 -6
  44. sglang/srt/layers/modelopt_utils.py +11 -0
  45. sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
  46. sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
  47. sglang/srt/layers/moe/ep_moe/layer.py +8 -286
  48. sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
  49. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  50. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  51. sglang/srt/layers/moe/utils.py +7 -1
  52. sglang/srt/layers/quantization/__init__.py +1 -1
  53. sglang/srt/layers/quantization/fp8.py +84 -18
  54. sglang/srt/layers/quantization/modelopt_quant.py +1 -1
  55. sglang/srt/layers/quantization/quark/quark.py +3 -1
  56. sglang/srt/layers/quantization/w4afp8.py +2 -16
  57. sglang/srt/lora/lora_manager.py +0 -8
  58. sglang/srt/managers/overlap_utils.py +18 -16
  59. sglang/srt/managers/schedule_batch.py +119 -90
  60. sglang/srt/managers/schedule_policy.py +1 -1
  61. sglang/srt/managers/scheduler.py +213 -126
  62. sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
  63. sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
  64. sglang/srt/managers/tokenizer_manager.py +270 -53
  65. sglang/srt/managers/tp_worker.py +39 -28
  66. sglang/srt/mem_cache/allocator.py +7 -2
  67. sglang/srt/mem_cache/chunk_cache.py +1 -1
  68. sglang/srt/mem_cache/memory_pool.py +162 -68
  69. sglang/srt/mem_cache/radix_cache.py +8 -3
  70. sglang/srt/mem_cache/swa_radix_cache.py +70 -14
  71. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  72. sglang/srt/model_executor/forward_batch_info.py +4 -18
  73. sglang/srt/model_executor/model_runner.py +55 -51
  74. sglang/srt/model_loader/__init__.py +1 -1
  75. sglang/srt/model_loader/loader.py +187 -6
  76. sglang/srt/model_loader/weight_utils.py +3 -0
  77. sglang/srt/models/falcon_h1.py +11 -9
  78. sglang/srt/models/gemma3_mm.py +16 -0
  79. sglang/srt/models/grok.py +5 -13
  80. sglang/srt/models/mixtral.py +1 -3
  81. sglang/srt/models/mllama4.py +11 -1
  82. sglang/srt/models/nemotron_h.py +514 -0
  83. sglang/srt/models/utils.py +5 -1
  84. sglang/srt/sampling/sampling_batch_info.py +11 -9
  85. sglang/srt/server_args.py +100 -33
  86. sglang/srt/speculative/eagle_worker.py +11 -13
  87. sglang/srt/speculative/ngram_worker.py +12 -11
  88. sglang/srt/speculative/spec_utils.py +0 -1
  89. sglang/srt/two_batch_overlap.py +1 -0
  90. sglang/srt/utils/common.py +18 -0
  91. sglang/srt/utils/hf_transformers_utils.py +2 -0
  92. sglang/test/longbench_v2/__init__.py +1 -0
  93. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  94. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  95. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  96. sglang/test/run_eval.py +40 -0
  97. sglang/test/simple_eval_longbench_v2.py +332 -0
  98. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  99. sglang/test/test_deterministic.py +18 -2
  100. sglang/test/test_deterministic_utils.py +81 -0
  101. sglang/test/test_disaggregation_utils.py +63 -0
  102. sglang/test/test_utils.py +32 -11
  103. sglang/version.py +1 -1
  104. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
  105. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
  106. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  107. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  108. sglang/test/test_block_fp8_ep.py +0 -358
  109. /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
  110. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  111. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  112. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py
@@ -182,6 +182,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             if speculative_algorithm.is_none()
             else server_args.speculative_num_draft_tokens
         )
+        # Initialize delimiter text for multi-item scoring (will be set after tokenizer is loaded)
+        self.multi_item_delimiter_text = None

         if self.model_config.is_multimodal:
             import_processors("sglang.srt.multimodal.processors")
@@ -223,6 +225,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             self.processor = _processor
             self.tokenizer = get_tokenizer_from_processor(self.processor)
             os.environ["TOKENIZERS_PARALLELISM"] = "false"
+            self._initialize_multi_item_delimiter_text()
         else:
             self.mm_processor = self.processor = None

@@ -235,6 +238,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 trust_remote_code=server_args.trust_remote_code,
                 revision=server_args.revision,
             )
+            self._initialize_multi_item_delimiter_text()
         # Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal)
         if (
             server_args.enable_dynamic_batch_tokenizer
@@ -1678,6 +1682,201 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             if len(self.model_update_tmp) == self.server_args.dp_size:
                 self.model_update_result.set_result(self.model_update_tmp)

+    def _initialize_multi_item_delimiter_text(self):
+        """Initialize multi-item delimiter text from token ID after tokenizer is loaded."""
+        if (
+            hasattr(self.server_args, "multi_item_scoring_delimiter")
+            and self.server_args.multi_item_scoring_delimiter is not None
+            and self.tokenizer is not None
+        ):
+            try:
+                self.multi_item_delimiter_text = self.tokenizer.decode(
+                    [self.server_args.multi_item_scoring_delimiter],
+                    skip_special_tokens=False,
+                )
+            except Exception as e:
+                logger.warning(
+                    f"Failed to decode delimiter token {self.server_args.multi_item_scoring_delimiter}: {e}"
+                )
+                self.multi_item_delimiter_text = None
+
+    def _build_multi_item_token_sequence(
+        self, query: List[int], items: List[List[int]], delimiter_token_id: int
+    ) -> List[int]:
+        """
+        Build a single token sequence for multi-item scoring.
+        Format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
+
+        Args:
+            query: Query token IDs
+            items: List of item token ID sequences
+            delimiter_token_id: Token ID to use as delimiter
+
+        Returns:
+            Combined token sequence
+        """
+        combined_sequence = query[:]  # Start with query
+
+        for item in items:
+            combined_sequence.append(delimiter_token_id)  # Add delimiter
+            combined_sequence.extend(item)  # Add item tokens
+
+        # Add final delimiter after the last item for logprob extraction
+        combined_sequence.append(delimiter_token_id)
+
+        return combined_sequence
+
+    def _extract_logprobs_for_tokens(
+        self, logprobs_data: List, label_token_ids: List[int]
+    ) -> Dict[int, float]:
+        """
+        Extract logprobs for specified token IDs from logprobs data.
+
+        Args:
+            logprobs_data: List of (logprob, token_id, text) tuples
+            label_token_ids: Token IDs to extract logprobs for
+
+        Returns:
+            Dictionary mapping token_id to logprob
+        """
+        logprobs = {}
+        if logprobs_data:
+            for logprob, token_id, _ in logprobs_data:
+                if token_id in label_token_ids:
+                    logprobs[token_id] = logprob
+        return logprobs
+
+    def _convert_logprobs_to_scores(
+        self,
+        logprobs: Dict[int, float],
+        label_token_ids: List[int],
+        apply_softmax: bool,
+    ) -> List[float]:
+        """
+        Convert logprobs dictionary to ordered score list.
+
+        Args:
+            logprobs: Dictionary mapping token_id to logprob
+            label_token_ids: Token IDs in desired order
+            apply_softmax: Whether to apply softmax normalization
+
+        Returns:
+            List of scores in the same order as label_token_ids
+        """
+        score_list = [
+            logprobs.get(token_id, float("-inf")) for token_id in label_token_ids
+        ]
+
+        if apply_softmax:
+            score_list = torch.softmax(torch.tensor(score_list), dim=0).tolist()
+        else:
+            # Convert logprobs to probabilities if not using softmax
+            score_list = [
+                math.exp(x) if x != float("-inf") else 0.0 for x in score_list
+            ]
+
+        return score_list
+
+    def _process_multi_item_scoring_results(
+        self,
+        results: Any,
+        items: List,
+        label_token_ids: List[int],
+        apply_softmax: bool,
+        batch_request=None,
+    ) -> List[List[float]]:
+        """
+        Process results from multi-item scoring request.
+        Extracts logprobs at delimiter positions from input_token_ids_logprobs.
+
+        Args:
+            results: Results from generate_request
+            items: List of items being scored
+            label_token_ids: Token IDs to extract scores for
+            apply_softmax: Whether to apply softmax normalization
+            batch_request: The original batch request containing input sequence
+
+        Returns:
+            List of score lists, one for each item
+        """
+        single_result = results[0] if isinstance(results, list) else results
+
+        # For multi-item scoring, logprobs are in input_token_ids_logprobs
+        input_logprobs = single_result["meta_info"].get("input_token_ids_logprobs", [])
+
+        if not input_logprobs:
+            raise RuntimeError(
+                f"input_token_ids_logprobs is empty for multi-item scoring request {single_result['meta_info'].get('id', '<unknown>')}. "
+                "This indicates token_ids_logprobs were not computed properly for Mutil Item Scoring."
+            )
+
+        scores = []
+        num_items = len(items) if isinstance(items, list) else 1
+
+        # Check if we have the expected number of logprobs
+        expected_logprobs_count = num_items + 1
+        if len(input_logprobs) != expected_logprobs_count:
+            raise RuntimeError(
+                f"Expected {expected_logprobs_count} input_token_ids_logprobs for multi-item scoring "
+                f"with {num_items} items, but got {len(input_logprobs)}. "
+                f"Request ID: {single_result['meta_info'].get('id', '<unknown>')}"
+            )
+
+        # Skip the first delimiter (between query and first item) and process remaining delimiter positions
+        # We want to exclude the first one since it represents the boundary between query and first item, not an item boundary
+        start_idx = 1 if len(input_logprobs) > 1 else 0
+
+        # Process logprobs for each item position (excluding first delimiter)
+        for item_idx in range(num_items):
+            logprob_idx = start_idx + item_idx
+            item_logprobs_data = input_logprobs[logprob_idx]
+            logprobs = self._extract_logprobs_for_tokens(
+                item_logprobs_data, label_token_ids
+            )
+            score_list = self._convert_logprobs_to_scores(
+                logprobs, label_token_ids, apply_softmax
+            )
+            scores.append(score_list)
+
+        return scores
+
+    def _process_single_item_scoring_results(
+        self, results: Any, label_token_ids: List[int], apply_softmax: bool
+    ) -> List[List[float]]:
+        """
+        Process results from single-item scoring request.
+        Single-item scoring results are stored in output_token_ids_logprobs.
+
+        Args:
+            results: Results from generate_request
+            label_token_ids: Token IDs to extract scores for
+            apply_softmax: Whether to apply softmax normalization
+
+        Returns:
+            List of score lists, one for each result
+        """
+        scores = []
+
+        for result in results:
+            # For single-item scoring, logprobs are in output_token_ids_logprobs
+            output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
+
+            if not output_logprobs or len(output_logprobs) == 0:
+                raise RuntimeError(
+                    f"output_logprobs is empty for request {result['meta_info'].get('id', '<unknown>')}."
+                )
+
+            # Extract logprobs for the first (and only) position
+            logprobs = self._extract_logprobs_for_tokens(
+                output_logprobs[0], label_token_ids
+            )
+            score_list = self._convert_logprobs_to_scores(
+                logprobs, label_token_ids, apply_softmax
+            )
+            scores.append(score_list)
+
+        return scores
+
     async def score_request(
         self,
         query: Optional[Union[str, List[int]]] = None,
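For readers tracing the new helpers, the short standalone sketch below (not part of the diff) mirrors the layout that _build_multi_item_token_sequence produces: the query tokens, then each item preceded by the delimiter token, plus one trailing delimiter so every item has a delimiter position for logprob extraction. All token IDs are made up for illustration.

    # Standalone sketch of the layout built by _build_multi_item_token_sequence.
    # Token IDs below are illustrative, not real vocabulary entries.
    def build_multi_item_token_sequence(query, items, delimiter_token_id):
        combined = query[:]
        for item in items:
            combined.append(delimiter_token_id)
            combined.extend(item)
        combined.append(delimiter_token_id)  # trailing delimiter for logprob extraction
        return combined

    query = [101, 102]     # hypothetical query tokens
    items = [[7], [8, 9]]  # two hypothetical items
    seq = build_multi_item_token_sequence(query, items, delimiter_token_id=50256)
    assert seq == [101, 102, 50256, 7, 50256, 8, 9, 50256]
    # With N items there are N + 1 delimiter positions, matching the
    # expected_logprobs_count check in _process_multi_item_scoring_results.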
@@ -1688,7 +1887,29 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         request: Optional[Any] = None,
     ) -> List[List[float]]:
         """
-        See Engine.score() for more details.
+        Score the probability of specified token IDs appearing after the given (query + item) pair.
+
+        This method supports two scoring approaches:
+        1. Single-Item scoring (default): Process each query+item pair independently
+        2. Multi-Item scoring: When multi_item_scoring_delimiter is set, combine query and
+           multiple items into a single sequence using delimiter for efficient processing.
+           Note: item_first parameter is ignored in multi-item scoring mode since it uses
+           a fixed format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
+
+        Multi-item scoring works with both text and pre-tokenized inputs:
+        - Text: query<delimiter_text>item1<delimiter_text>item2<delimiter_text>item3<delimiter_text>
+        - Tokens: query<delimiter_token_id>item1<delimiter_token_id>item2<delimiter_token_id>item3<delimiter_token_id>
+
+        Args:
+            query: The query text or pre-tokenized query token IDs
+            items: The item text(s) or pre-tokenized item token IDs
+            label_token_ids: List of token IDs to compute probabilities for
+            apply_softmax: Whether to normalize probabilities using softmax
+            item_first: If True, prepend items to query. Ignored for multi-item scoring.
+            request: Optional FastAPI request object
+
+        Returns:
+            List of lists containing probabilities for each item and each label token
         """
         if label_token_ids is None:
             raise ValueError("label_token_ids must be provided")
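To make the documented signature concrete, here is an illustrative call only: it assumes an initialized TokenizerManager instance and uses made-up label token IDs, so it sketches the argument and return shapes rather than prescribing values.

    # Illustrative only; `tokenizer_manager` and the token IDs are assumptions.
    async def score_two_reviews(tokenizer_manager):
        scores = await tokenizer_manager.score_request(
            query="Is the following review positive?\n",
            items=["Great product!", "Terrible experience."],
            label_token_ids=[9891, 2201],  # hypothetical IDs for "Yes" / "No"
            apply_softmax=True,
        )
        # scores is a List[List[float]]: one inner list per item,
        # ordered the same way as label_token_ids.
        return scores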
@@ -1701,9 +1922,17 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                     f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
                 )

+        # Check if multi-item scoring is enabled by presence of delimiter
+        use_multi_item_scoring = (
+            self.server_args.multi_item_scoring_delimiter is not None
+            and self.multi_item_delimiter_text is not None
+        )
+
         batch_request = GenerateReqInput(
             token_ids_logprob=label_token_ids,
             return_logprob=True,
+            # Set logprob_start_len=0 for multi-item scoring since we want logprobs at all delimiter positions
+            logprob_start_len=0 if use_multi_item_scoring else -1,
             stream=False,
             sampling_params={"max_new_tokens": 0},
         )
@@ -1715,12 +1944,23 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         ):
             # Both query and items are text
             items_list = [items] if isinstance(items, str) else items
-            if item_first:
-                prompts = [f"{item}{query}" for item in items_list]
-            else:
-                prompts = [f"{query}{item}" for item in items_list]

-            batch_request.text = prompts
+            if use_multi_item_scoring:
+                # Multi-item scoring: create single prompt with delimiter text
+                # Always use format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
+                # (item_first is ignored for multi-item scoring)
+                delimiter = self.multi_item_delimiter_text
+                combined_items = delimiter.join(items_list)
+                # Add final delimiter after the last item for logprob extraction
+                single_prompt = f"{query}{delimiter}{combined_items}{delimiter}"
+                batch_request.text = [single_prompt]
+            else:
+                # Single-item scoring: create separate prompts for each item
+                if item_first:
+                    prompts = [f"{item}{query}" for item in items_list]
+                else:
+                    prompts = [f"{query}{item}" for item in items_list]
+                batch_request.text = prompts

         elif (
             isinstance(query, list)
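The text branch above builds one combined prompt when multi-item scoring is enabled. A minimal sketch of the resulting layout, with "<sep>" standing in for whatever the configured delimiter token decodes to:

    # Sketch of the text-mode prompt layout for multi-item scoring.
    delimiter = "<sep>"  # stand-in for self.multi_item_delimiter_text
    query = "Q"
    items_list = ["A", "B", "C"]
    single_prompt = f"{query}{delimiter}{delimiter.join(items_list)}{delimiter}"
    assert single_prompt == "Q<sep>A<sep>B<sep>C<sep>"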
@@ -1729,61 +1969,38 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             and isinstance(items[0], list)
         ):
             # Both query and items are token IDs
-            if item_first:
-                input_ids_list = [item + query for item in items]
+            if use_multi_item_scoring:
+                # Multi-item scoring: concatenate with delimiter token ID
+                # Format: query<delimiter_token_id>item1<delimiter_token_id>item2<delimiter_token_id>item3<delimiter_token_id>
+                delimiter_token_id = self.server_args.multi_item_scoring_delimiter
+                combined_input_ids = self._build_multi_item_token_sequence(
+                    query, items, delimiter_token_id
+                )
+                batch_request.input_ids = [combined_input_ids]
             else:
-                input_ids_list = [query + item for item in items]
-
-            batch_request.input_ids = input_ids_list
+                # Single-item scoring: process each item separately
+                if item_first:
+                    input_ids_list = [item + query for item in items]
+                else:
+                    input_ids_list = [query + item for item in items]
+                batch_request.input_ids = input_ids_list
         else:
             raise ValueError(
                 "Invalid combination of query/items types for score_request."
             )

         results = await self.generate_request(batch_request, request).__anext__()
-        scores = []
-
-        for result in results:
-            # Get logprobs for each token
-            logprobs = {}
-
-            # For scoring requests, we read from output_token_ids_logprobs since we want
-            # the logprobs for specific tokens mentioned in the label_token_ids at
-            # the next position after the last token in the prompt
-            output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
-
-            # Check if output_logprobs is properly populated
-            if (
-                output_logprobs is None
-                or not output_logprobs
-                or len(output_logprobs) == 0
-            ):
-                raise RuntimeError(
-                    f"output_logprobs is empty for request {result['meta_info'].get('id', '<unknown>')}. "
-                    "This indicates token_ids_logprobs were not computed properly for the scoring request."
-                )
-
-            for logprob, token_id, _ in output_logprobs[0]:
-                if token_id in label_token_ids:
-                    logprobs[token_id] = logprob

-            # Get scores in order of label_token_ids
-            score_list = [
-                logprobs.get(token_id, float("-inf")) for token_id in label_token_ids
-            ]
-
-            # Apply softmax to logprobs if needed
-            if apply_softmax:
-                score_list = torch.softmax(torch.tensor(score_list), dim=0).tolist()
-            else:
-                # Convert logprobs to probabilities if not using softmax
-                score_list = [
-                    math.exp(x) if x != float("-inf") else 0.0 for x in score_list
-                ]
-
-            scores.append(score_list)
-
-        return scores
+        if use_multi_item_scoring:
+            # Multi-item scoring: extract scores from input_token_ids_logprobs
+            return self._process_multi_item_scoring_results(
+                results, items, label_token_ids, apply_softmax, batch_request
+            )
+        else:
+            # Single-item scoring: process each result separately
+            return self._process_single_item_scoring_results(
+                results, label_token_ids, apply_softmax
+            )

     async def watch_load_thread(self):
         # Only for dp_controller when dp_size > 1
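Both result-processing paths end in _convert_logprobs_to_scores. The standalone sketch below shows that conversion for two hypothetical label tokens and how apply_softmax changes the output; the logprob values are invented.

    import math
    import torch

    label_token_ids = [11, 22]       # hypothetical label token IDs
    logprobs = {11: -0.1, 22: -2.3}  # invented logprobs read at one delimiter position

    raw = [logprobs.get(t, float("-inf")) for t in label_token_ids]
    softmaxed = torch.softmax(torch.tensor(raw), dim=0).tolist()
    plain = [math.exp(x) if x != float("-inf") else 0.0 for x in raw]
    # softmaxed ~ [0.900, 0.100]  (normalized across the label set)
    # plain     ~ [0.905, 0.100]  (independent probabilities exp(logprob))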
sglang/srt/managers/tp_worker.py
@@ -15,14 +15,12 @@
 from __future__ import annotations

 import logging
-import threading
-from typing import TYPE_CHECKING, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional

 import torch

 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import get_pp_group, get_world_group
-from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     DestroyWeightsUpdateGroupReqInput,
     GetWeightsByNameReqInput,
@@ -36,13 +34,10 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
+from sglang.srt.managers.scheduler import GenerationBatchResult
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
-from sglang.srt.model_executor.forward_batch_info import (
-    ForwardBatch,
-    ForwardBatchOutput,
-    PPProxyTensors,
-)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
@@ -236,9 +231,8 @@ class TpModelWorker:
     def forward_batch_generation(
         self,
         model_worker_batch: ModelWorkerBatch,
-        launch_done: Optional[threading.Event] = None,
         is_verify: bool = False,
-    ) -> ForwardBatchOutput:
+    ) -> GenerationBatchResult:
         # update the consumer index of hicache to the running batch
         self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)

@@ -256,32 +250,49 @@ class TpModelWorker:
             logits_output, can_run_cuda_graph = self.model_runner.forward(
                 forward_batch, pp_proxy_tensors=pp_proxy_tensors
             )
-            if launch_done is not None:
-                launch_done.set()
-
-            skip_sample = is_verify or model_worker_batch.is_prefill_only
-            next_token_ids = None
-
-            if not skip_sample:
-                next_token_ids = self.model_runner.sample(logits_output, forward_batch)
-            elif model_worker_batch.return_logprob and not is_verify:
-                # NOTE: Compute logprobs without full sampling
-                self.model_runner.compute_logprobs_only(
-                    logits_output, model_worker_batch
-                )
-
-            return ForwardBatchOutput(
+            batch_result = GenerationBatchResult(
                 logits_output=logits_output,
-                next_token_ids=next_token_ids,
                 can_run_cuda_graph=can_run_cuda_graph,
             )
+
+            if is_verify:
+                # Skip sampling and return logits for target forward
+                return batch_result
+
+            if model_worker_batch.delay_sample_launch:
+                batch_result.delay_sample_launch = True
+                batch_result.forward_batch = forward_batch
+                return batch_result
+
+            if model_worker_batch.is_prefill_only:
+                # For prefill-only requests, create dummy token IDs on CPU
+                # The size should match the batch size (number of sequences), not total tokens
+                batch_result.next_token_ids = torch.zeros(
+                    len(model_worker_batch.seq_lens),
+                    dtype=torch.long,
+                    device=model_worker_batch.input_ids.device,
+                )
+                if (
+                    model_worker_batch.return_logprob
+                    and logits_output.next_token_logits is not None
+                ):
+                    # NOTE: Compute logprobs without full sampling
+                    self.model_runner.compute_logprobs_only(
+                        logits_output, model_worker_batch
+                    )
+            else:
+                batch_result.next_token_ids = self.model_runner.sample(
+                    logits_output, forward_batch
+                )
+
+            return batch_result
         else:
             pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
                 forward_batch,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-            return ForwardBatchOutput(
-                pp_proxy_tensors=pp_proxy_tensors,
+            return GenerationBatchResult(
+                pp_hidden_states_proxy_tensors=pp_proxy_tensors,
                 can_run_cuda_graph=can_run_cuda_graph,
             )

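forward_batch_generation now returns the scheduler's GenerationBatchResult instead of ForwardBatchOutput. The real class is defined in sglang/srt/managers/scheduler.py; the stub below is only a reading aid listing the fields this hunk reads or writes, not the actual definition.

    # Reading aid, not the real class: fields of GenerationBatchResult used above.
    from dataclasses import dataclass
    from typing import Any, Optional

    @dataclass
    class GenerationBatchResultStub:
        logits_output: Optional[Any] = None                    # last PP stage output
        pp_hidden_states_proxy_tensors: Optional[Any] = None   # non-last PP stages
        next_token_ids: Optional[Any] = None                   # sampled or dummy token IDs
        can_run_cuda_graph: bool = False
        delay_sample_launch: bool = False                      # sampling deferred to the scheduler
        forward_batch: Optional[Any] = None                    # kept when sampling is delayed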
sglang/srt/mem_cache/allocator.py
@@ -274,10 +274,15 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             self.full_to_swa_index_mapping[free_index] = 0

     def backup_state(self):
-        raise NotImplementedError
+        return [
+            self.full_attn_allocator.backup_state(),
+            self.swa_attn_allocator.backup_state(),
+        ]

     def restore_state(self, state):
-        raise NotImplementedError
+        assert len(state) == 2
+        self.full_attn_allocator.restore_state(state[0])
+        self.swa_attn_allocator.restore_state(state[1])

     def clear(self):
         self.swa_attn_allocator.clear()
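backup_state and restore_state on SWATokenToKVPoolAllocator now delegate to the full-attention and SWA allocators instead of raising NotImplementedError. A sketch of the intended round trip, assuming `allocator` is such an instance:

    # Illustrative round trip; `allocator` is assumed to be an SWATokenToKVPoolAllocator.
    def with_state_snapshot(allocator, fn):
        state = allocator.backup_state()    # [full_attn_state, swa_attn_state]
        try:
            return fn()
        finally:
            allocator.restore_state(state)  # expects exactly the two states back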
sglang/srt/mem_cache/chunk_cache.py
@@ -60,7 +60,7 @@ class ChunkCache(BasePrefixCache):
         ]

         # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
-        req.prefix_indices = kv_indices
+        req.prefix_indices = kv_indices.to(dtype=torch.int64, copy=True)

     def evict(self, num_tokens: int):
         pass
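The ChunkCache change stores req.prefix_indices as an independent int64 copy instead of a reference to kv_indices. A small torch sketch of why copy=True matters (the values are arbitrary):

    import torch

    kv_indices = torch.tensor([3, 5, 7], dtype=torch.int32)
    alias = kv_indices                                     # old behavior: same tensor object
    copied = kv_indices.to(dtype=torch.int64, copy=True)   # new behavior: detached int64 copy

    kv_indices[0] = 99
    assert alias[0] == 99 and copied[0] == 3 and copied.dtype == torch.int64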