sglang-0.4.7-py3-none-any.whl → sglang-0.4.8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
@@ -6,6 +6,7 @@ import random
 import threading
 import warnings
 from collections import deque
+from contextlib import nullcontext
 from enum import Enum
 from typing import TYPE_CHECKING, List, Optional
 
@@ -14,15 +15,15 @@ import requests
 import torch
 import torch.distributed as dist
 
-from sglang.srt.utils import get_ip, get_local_ip_by_remote
+from sglang.srt.utils import get_ip
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
 
-FakeBootstrapHost = "2.2.2.2"
-
-# env var for testing failure, convert to float explicitly
-FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0))
+#########################
+# Constants & Enums
+#########################
+FAKE_BOOTSTRAP_HOST = "2.2.2.2"
 
 
 class DisaggregationMode(Enum):
@@ -31,6 +32,14 @@ class DisaggregationMode(Enum):
     DECODE = "decode"
 
 
+#########################
+# Synchronization
+#########################
+
+# env var for testing failure, convert to float explicitly
+FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0))
+
+
 def poll_and_all_reduce(pollers, gloo_group):
     # at a certain prob, the poll is failed to simulate failure
     if FAILURE_PROB > 0:
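Note on the hunk above (apparently sglang/srt/disaggregation/utils.py, item 28 in the file list): the failure-injection switch only moves under a new section header; FAILURE_PROB is still read once from the environment at import time. A minimal sketch of how a test might drive it — setting the variable before the import is an assumption of this sketch, not something the diff prescribes:

    import os

    # FAILURE_PROB is evaluated when the module is first imported (see the hunk above),
    # so the environment variable has to be set before that import happens.
    os.environ["DISAGGREGATION_TEST_FAILURE_PROB"] = "0.05"

    from sglang.srt.disaggregation.utils import FAILURE_PROB

    assert FAILURE_PROB == 0.05  # ~5% of poll_and_all_reduce calls simulate a failure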
@@ -47,6 +56,11 @@ def poll_and_all_reduce(pollers, gloo_group):
     return tensor_to_reduce.tolist()
 
 
+#########################
+# Metadata Buffers
+#########################
+
+
 class ReqToMetadataIdxAllocator:
     """A memory pool that maps a request to its first output token location."""
 
@@ -70,6 +84,118 @@ class ReqToMetadataIdxAllocator:
         self.free_slots.append(free_index)
 
 
+class MetadataBuffers:
+    def __init__(
+        self,
+        size: int,
+        hidden_size: int,
+        dtype: torch.dtype,
+        max_top_logprobs_num: int = 128,
+        custom_mem_pool: torch.cuda.MemPool = None,
+    ):
+        self.custom_mem_pool = custom_mem_pool
+        device = "cuda" if self.custom_mem_pool else "cpu"
+
+        with (
+            torch.cuda.use_mem_pool(self.custom_mem_pool)
+            if self.custom_mem_pool
+            else nullcontext()
+        ):
+            # TODO: abort top_logprobs_num > 128 in PD
+
+            # We transfer the metadata of first output token to decode
+            # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
+            self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
+
+            self.output_hidden_states = torch.zeros(
+                (size, hidden_size), dtype=dtype, device=device
+            )
+            self.output_token_logprobs_val = torch.zeros(
+                (size, 16), dtype=torch.float32, device=device
+            )
+            self.output_token_logprobs_idx = torch.zeros(
+                (size, 16), dtype=torch.int32, device=device
+            )
+            self.output_top_logprobs_val = torch.zeros(
+                (size, max_top_logprobs_num), dtype=torch.float32, device=device
+            )
+            self.output_top_logprobs_idx = torch.zeros(
+                (size, max_top_logprobs_num), dtype=torch.int32, device=device
+            )
+
+    def get_buf_infos(self):
+        ptrs = [
+            self.output_ids.data_ptr(),
+            self.output_hidden_states.data_ptr(),  # TODO: set None to avoid transfer hidden_states when spec_algorithm is None
+            self.output_token_logprobs_val.data_ptr(),
+            self.output_token_logprobs_idx.data_ptr(),
+            self.output_top_logprobs_val.data_ptr(),
+            self.output_top_logprobs_idx.data_ptr(),
+        ]
+        data_lens = [
+            self.output_ids.nbytes,
+            self.output_hidden_states.nbytes,
+            self.output_token_logprobs_val.nbytes,
+            self.output_token_logprobs_idx.nbytes,
+            self.output_top_logprobs_val.nbytes,
+            self.output_top_logprobs_idx.nbytes,
+        ]
+        item_lens = [
+            self.output_ids[0].nbytes,
+            self.output_hidden_states[0].nbytes,
+            self.output_token_logprobs_val[0].nbytes,
+            self.output_token_logprobs_idx[0].nbytes,
+            self.output_top_logprobs_val[0].nbytes,
+            self.output_top_logprobs_idx[0].nbytes,
+        ]
+        return ptrs, data_lens, item_lens
+
+    def get_buf(self, idx: int):
+        return (
+            self.output_ids[idx],
+            self.output_hidden_states[idx],
+            self.output_token_logprobs_val[idx],
+            self.output_token_logprobs_idx[idx],
+            self.output_top_logprobs_val[idx],
+            self.output_top_logprobs_idx[idx],
+        )
+
+    def set_buf(self, req: Req):
+
+        self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
+        if req.hidden_states_tensor is not None:
+            self.output_hidden_states[req.metadata_buffer_index].copy_(
+                req.hidden_states_tensor
+            )
+        if req.return_logprob:
+            if req.output_token_logprobs_val:  # not none or empty list
+                self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
+                    req.output_token_logprobs_val[0]
+                )
+            if req.output_token_logprobs_idx:  # not none or empty list
+                self.output_token_logprobs_idx[req.metadata_buffer_index][0] = (
+                    req.output_token_logprobs_idx[0]
+                )
+
+            if req.output_top_logprobs_val:  # not none or empty list
+                self.output_top_logprobs_val[req.metadata_buffer_index][
+                    : len(req.output_top_logprobs_val[0])
+                ] = torch.tensor(
+                    req.output_top_logprobs_val[0], dtype=torch.float32, device="cpu"
+                )
+            if req.output_top_logprobs_idx:  # not none or empty list
+                self.output_top_logprobs_idx[req.metadata_buffer_index][
+                    : len(req.output_top_logprobs_idx[0])
+                ] = torch.tensor(
+                    req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
+                )
+
+
+#########################
+# Transfer Backend
+#########################
+
+
 class TransferBackend(Enum):
     MOONCAKE = "mooncake"
     NIXL = "nixl"
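The relocated MetadataBuffers now takes hidden_size, dtype, and an optional custom_mem_pool, and allocates an output_hidden_states buffer alongside the logprob buffers. A minimal construction sketch, assuming CPU allocation (no custom_mem_pool); the concrete sizes are illustrative, not taken from the diff:

    import torch

    from sglang.srt.disaggregation.utils import MetadataBuffers

    # With custom_mem_pool=None the buffers live on CPU, as before; passing a
    # torch.cuda.MemPool moves the allocations under torch.cuda.use_mem_pool().
    buffers = MetadataBuffers(
        size=256,              # illustrative: number of in-flight requests
        hidden_size=4096,      # illustrative: model hidden size
        dtype=torch.bfloat16,  # dtype of the transferred hidden states
    )

    ptrs, data_lens, item_lens = buffers.get_buf_infos()  # base pointers and byte lengths per buffer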
@@ -77,6 +203,7 @@ class TransferBackend(Enum):
 
 
 class KVClassType(Enum):
+    KVARGS = "kvargs"
     MANAGER = "manager"
     SENDER = "sender"
     RECEIVER = "receiver"
@@ -87,6 +214,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
     from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
 
     if transfer_backend == TransferBackend.MOONCAKE:
+        from sglang.srt.disaggregation.base import KVArgs
         from sglang.srt.disaggregation.mooncake import (
             MooncakeKVBootstrapServer,
             MooncakeKVManager,
@@ -95,13 +223,15 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
         )
 
         class_mapping = {
+            KVClassType.KVARGS: KVArgs,
             KVClassType.MANAGER: MooncakeKVManager,
             KVClassType.SENDER: MooncakeKVSender,
             KVClassType.RECEIVER: (MooncakeKVReceiver),
             KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
         }
         return class_mapping.get(class_type)
-    if transfer_backend == TransferBackend.NIXL:
+    elif transfer_backend == TransferBackend.NIXL:
+        from sglang.srt.disaggregation.base import KVArgs
         from sglang.srt.disaggregation.nixl import (
             NixlKVBootstrapServer,
             NixlKVManager,
@@ -110,16 +240,19 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
         )
 
         class_mapping = {
+            KVClassType.KVARGS: KVArgs,
             KVClassType.MANAGER: NixlKVManager,
             KVClassType.SENDER: NixlKVSender,
             KVClassType.RECEIVER: (NixlKVReceiver),
             KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
         }
         return class_mapping.get(class_type)
-    if transfer_backend == TransferBackend.FAKE:
+    elif transfer_backend == TransferBackend.FAKE:
+        from sglang.srt.disaggregation.base import KVArgs
         from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
 
         class_mapping = {
+            KVClassType.KVARGS: KVArgs,
            KVClassType.SENDER: FakeKVSender,
            KVClassType.RECEIVER: (FakeKVReceiver),
        }
@@ -128,6 +261,11 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
     raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
 
 
+#########################
+# KV Pages
+#########################
+
+
 def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
     # 1. The page is guaranteed to be full except the last page.
     # 2. page index = kv_index // page_size
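The new KVClassType.KVARGS entry lets callers resolve the backend-agnostic KVArgs container through the same factory as the manager, sender, and receiver classes. A sketch of the call pattern, mirroring the mapping above; instantiating the returned class with no arguments is an assumption:

    from sglang.srt.disaggregation.utils import KVClassType, TransferBackend, get_kv_class

    # Every backend branch now also maps KVClassType.KVARGS to
    # sglang.srt.disaggregation.base.KVArgs.
    kv_args_cls = get_kv_class(TransferBackend.MOONCAKE, KVClassType.KVARGS)
    manager_cls = get_kv_class(TransferBackend.MOONCAKE, KVClassType.MANAGER)

    kv_args = kv_args_cls()  # assumed: a plain argument container the caller fills in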
@@ -143,6 +281,11 @@ def kv_to_page_num(num_kv_indices: int, page_size: int):
     return (num_kv_indices + page_size - 1) // page_size
 
 
+#########################
+# PDLB Registry
+#########################
+
+
 @dataclasses.dataclass
 class PDRegistryRequest:
     """A request to register a machine itself to the LB."""
@@ -181,6 +324,11 @@ def register_disaggregation_server(
     )
 
 
+#########################
+# Misc
+#########################
+
+
 def is_mla_backend(target_kv_pool) -> bool:
     from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
 
@@ -200,119 +348,3 @@ def prepare_abort(req: Req, error_message: str, status_code=None):
     req.input_top_logprobs_idx = []
     req.input_token_ids_logprobs_val = []
     req.input_token_ids_logprobs_idx = []
-
-
-class MetadataBuffers:
-    def __init__(self, size: int, max_top_logprobs_num: int = 128):
-        # TODO: abort top_logprobs_num > 128 in PD
-
-        # We transfer the metadata of first output token to decode
-        # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
-        self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")
-        self.output_token_logprobs_val = torch.zeros(
-            (size, 16), dtype=torch.float32, device="cpu"
-        )
-        self.output_token_logprobs_idx = torch.zeros(
-            (size, 16), dtype=torch.int32, device="cpu"
-        )
-        self.output_top_logprobs_val = torch.zeros(
-            (size, max_top_logprobs_num), dtype=torch.float32, device="cpu"
-        )
-        self.output_top_logprobs_idx = torch.zeros(
-            (size, max_top_logprobs_num), dtype=torch.int32, device="cpu"
-        )
-
-    def get_buf_infos(self):
-        ptrs = [
-            self.output_ids.data_ptr(),
-            self.output_token_logprobs_val.data_ptr(),
-            self.output_token_logprobs_idx.data_ptr(),
-            self.output_top_logprobs_val.data_ptr(),
-            self.output_top_logprobs_idx.data_ptr(),
-        ]
-        data_lens = [
-            self.output_ids.nbytes,
-            self.output_token_logprobs_val.nbytes,
-            self.output_token_logprobs_idx.nbytes,
-            self.output_top_logprobs_val.nbytes,
-            self.output_top_logprobs_idx.nbytes,
-        ]
-        item_lens = [
-            self.output_ids[0].nbytes,
-            self.output_token_logprobs_val[0].nbytes,
-            self.output_token_logprobs_idx[0].nbytes,
-            self.output_top_logprobs_val[0].nbytes,
-            self.output_top_logprobs_idx[0].nbytes,
-        ]
-        return ptrs, data_lens, item_lens
-
-    def get_buf(self, idx: int):
-        return (
-            self.output_ids[idx],
-            self.output_token_logprobs_val[idx],
-            self.output_token_logprobs_idx[idx],
-            self.output_top_logprobs_val[idx],
-            self.output_top_logprobs_idx[idx],
-        )
-
-    def set_buf(self, req: Req):
-
-        self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
-        if req.return_logprob:
-            if req.output_token_logprobs_val:  # not none or empty list
-                self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
-                    req.output_token_logprobs_val[0]
-                )
-            if req.output_token_logprobs_idx:  # not none or empty list
-                self.output_token_logprobs_idx[req.metadata_buffer_index][0] = (
-                    req.output_token_logprobs_idx[0]
-                )
-
-            if req.output_top_logprobs_val:  # not none or empty list
-                self.output_top_logprobs_val[req.metadata_buffer_index][
-                    : len(req.output_top_logprobs_val[0])
-                ] = torch.tensor(
-                    req.output_top_logprobs_val[0], dtype=torch.float32, device="cpu"
-                )
-            if req.output_top_logprobs_idx:  # not none or empty list
-                self.output_top_logprobs_idx[req.metadata_buffer_index][
-                    : len(req.output_top_logprobs_idx[0])
-                ] = torch.tensor(
-                    req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
-                )
-
-
-class FastQueue:
-    def __init__(self):
-        self._buf = deque()
-        self._cond = threading.Condition()
-
-    def put(self, item):
-        with self._cond:
-            self._buf.append(item)
-            # wake up a thread of wait()
-            self._cond.notify()
-
-    def get(self):
-        with self._cond:
-            # if queue is empty ,block until is notified()
-            while not self._buf:
-                self._cond.wait()
-            return self._buf.popleft()
-
-
-def group_concurrent_contiguous(
-    src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
-) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
-    """Vectorised NumPy implementation."""
-    if src_indices.size == 0:
-        return [], []
-
-    brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
-    src_groups = np.split(src_indices, brk)
-    dst_groups = np.split(dst_indices, brk)
-
-    src_groups = [g.tolist() for g in src_groups]
-    dst_groups = [g.tolist() for g in dst_groups]
-
-    return src_groups, dst_groups
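FastQueue and group_concurrent_contiguous are removed from this module; given the new sglang/srt/disaggregation/common/utils.py (+42) in the file list, they most likely moved there rather than being dropped. For reference, a worked example of the grouping behavior implemented by the removed code:

    import numpy as np

    # group_concurrent_contiguous splits the index arrays wherever either sequence
    # stops being contiguous, so each (src, dst) pair of groups can be moved as one block.
    src = np.array([0, 1, 2, 5, 6], dtype=np.int64)
    dst = np.array([10, 11, 12, 20, 21], dtype=np.int64)

    brk = np.where((np.diff(src) != 1) | (np.diff(dst) != 1))[0] + 1
    src_groups = [g.tolist() for g in np.split(src, brk)]
    dst_groups = [g.tolist() for g in np.split(dst, brk)]

    print(src_groups)  # [[0, 1, 2], [5, 6]]
    print(dst_groups)  # [[10, 11, 12], [20, 21]]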
@@ -523,17 +523,25 @@ class GroupCoordinator:
         self,
         input_: torch.Tensor,
         dim: int = -1,
-        tensor_list: List[torch.Tensor] = None,
+        output_tensor_list: Optional[List[torch.Tensor]] = None,
     ) -> torch.Tensor:
         world_size = self.world_size
         # Bypass the function if we are using only 1 GPU.
         if world_size == 1:
-            return input_
+            if output_tensor_list is not None:
+                logger.warning(
+                    "Performing in-place all-gather with a group size of 1. "
+                    "This may be unnecessary; consider bypassing it for better efficiency."
+                )
+                output_tensor_list[0].copy_(input_)
+                return None
+            else:
+                return input_
 
-        if tensor_list is not None:
+        if output_tensor_list is not None:
             # TODO(ch-wan): support other backends
             return torch.distributed.all_gather(
-                tensor_list, input_, group=self.device_group
+                output_tensor_list, input_, group=self.device_group
             )
 
         assert (
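This hunk appears to belong to sglang/srt/distributed/parallel_state.py (item 29 in the file list). Callers that used the old tensor_list keyword must switch to output_tensor_list, and in the world_size == 1 in-place path the method now copies into the first output tensor and returns None. A small sketch, assuming a GroupCoordinator instance is obtained elsewhere:

    from typing import List

    import torch

    def gather_rows(group, x: torch.Tensor) -> torch.Tensor:
        # `group` stands in for a GroupCoordinator from parallel_state; how it is
        # created is outside this hunk. Keyword renamed: tensor_list -> output_tensor_list.
        outputs: List[torch.Tensor] = [torch.empty_like(x) for _ in range(group.world_size)]
        group.all_gather(x, output_tensor_list=outputs)  # may return None in the in-place path
        return torch.cat(outputs, dim=0)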
@@ -37,7 +37,6 @@ setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
 import torch
 import uvloop
 
-from sglang.srt.code_completion_parser import load_completion_template_for_openai_api
 from sglang.srt.entrypoints.EngineBase import EngineBase
 from sglang.srt.managers.data_parallel_controller import (
     run_data_parallel_controller_process,
@@ -58,11 +57,8 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
+from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
-from sglang.srt.openai_api.adapter import (
-    guess_chat_template_name_from_model_path,
-    load_chat_template_for_openai_api,
-)
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
@@ -123,12 +119,13 @@ class Engine(EngineBase):
         logger.info(f"{server_args=}")
 
         # Launch subprocesses
-        tokenizer_manager, scheduler_info = _launch_subprocesses(
+        tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
             server_args=server_args,
             port_args=port_args,
         )
         self.server_args = server_args
         self.tokenizer_manager = tokenizer_manager
+        self.template_manager = template_manager
         self.scheduler_info = scheduler_info
 
         context = zmq.Context(2)
@@ -175,7 +172,7 @@ class Engine(EngineBase):
         """
         if self.server_args.enable_dp_attention:
             if data_parallel_rank is None:
-                logger.info("data_parallel_rank not provided, using default dispatch")
+                logger.debug("data_parallel_rank not provided, using default dispatch")
             elif data_parallel_rank < 0:
                 raise ValueError("data_parallel_rank must be non-negative")
             elif data_parallel_rank >= self.server_args.dp_size:
@@ -258,7 +255,7 @@ class Engine(EngineBase):
 
         if self.server_args.enable_dp_attention:
             if data_parallel_rank is None:
-                logger.info("data_parallel_rank not provided, using default dispatch")
+                logger.debug("data_parallel_rank not provided, using default dispatch")
             elif data_parallel_rank < 0:
                 raise ValueError("data_parallel_rank must be non-negative")
             elif data_parallel_rank >= self.server_args.dp_size:
@@ -327,6 +324,20 @@ class Engine(EngineBase):
         generator = self.tokenizer_manager.generate_request(obj, None)
         return await generator.__anext__()
 
+    def rerank(
+        self,
+        prompt: Union[List[List[str]]],
+    ) -> Dict:
+        """
+        The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
+        Please refer to `EmbeddingReqInput` for the documentation.
+        """
+        obj = EmbeddingReqInput(text=prompt, is_cross_encoder_request=True)
+        loop = asyncio.get_event_loop()
+        generator = self.tokenizer_manager.generate_request(obj, None)
+        ret = loop.run_until_complete(generator.__anext__())
+        return ret
+
     def shutdown(self):
         """Shutdown the engine"""
         kill_process_tree(os.getpid(), include_parent=False)
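This and the surrounding hunks appear to come from sglang/srt/entrypoints/engine.py (item 30 in the file list). The new Engine.rerank wraps a cross-encoder EmbeddingReqInput. A usage sketch; the model path and the reading of each inner list as a (query, document) pair are assumptions, not stated by the diff:

    import sglang as sgl

    # Hypothetical cross-encoder checkpoint; rerank() forwards the pairs as an
    # EmbeddingReqInput with is_cross_encoder_request=True (see the hunk above).
    engine = sgl.Engine(model_path="BAAI/bge-reranker-v2-m3")

    scores = engine.rerank(
        [
            ["what is the capital of France?", "Paris is the capital of France."],
            ["what is the capital of France?", "The Nile is a river in Africa."],
        ]
    )
    print(scores)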
@@ -465,17 +476,15 @@ class Engine(EngineBase):
             self.tokenizer_manager.get_weights_by_name(obj, None)
         )
 
-    def release_memory_occupation(self):
-        """Release GPU occupation temporarily."""
-        obj = ReleaseMemoryOccupationReqInput()
+    def release_memory_occupation(self, tags: Optional[List[str]] = None):
+        obj = ReleaseMemoryOccupationReqInput(tags=tags)
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
             self.tokenizer_manager.release_memory_occupation(obj, None)
         )
 
-    def resume_memory_occupation(self):
-        """Resume GPU occupation."""
-        obj = ResumeMemoryOccupationReqInput()
+    def resume_memory_occupation(self, tags: Optional[List[str]] = None):
+        obj = ResumeMemoryOccupationReqInput(tags=tags)
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
             self.tokenizer_manager.resume_memory_occupation(obj, None)
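release_memory_occupation and resume_memory_occupation now accept an optional tags list that is forwarded to the request objects. A sketch; the tag strings are illustrative guesses (the accepted values are defined by the request types in io_struct.py, not in this hunk), and passing tags=None keeps the old release-everything behavior:

    import sglang as sgl

    engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # hypothetical model

    # Release only the tagged memory regions, then bring them back later.
    engine.release_memory_occupation(tags=["kv_cache"])   # tag name is illustrative
    # ... run other GPU work while the tagged memory is released ...
    engine.resume_memory_occupation(tags=["kv_cache"])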
@@ -605,7 +614,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda:
         assert_pkg_version(
             "sgl-kernel",
-            "0.1.7",
+            "0.1.9",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
 
@@ -635,7 +644,7 @@
 
 def _launch_subprocesses(
     server_args: ServerArgs, port_args: Optional[PortArgs] = None
-) -> Tuple[TokenizerManager, Dict]:
+) -> Tuple[TokenizerManager, TemplateManager, Dict]:
     """
     Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
     """
@@ -656,11 +665,9 @@
 
     scheduler_procs = []
     if server_args.dp_size == 1:
-        # Launch tensor parallel scheduler processes
         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=server_args.enable_memory_saver
         )
-
         scheduler_pipe_readers = []
 
         nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
@@ -696,6 +703,7 @@
                         writer,
                     ),
                 )
+
                 with memory_saver_adapter.configure_subprocess():
                     proc.start()
                 scheduler_procs.append(proc)
@@ -721,7 +729,7 @@
 
         if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
             # When using `Engine` as a Python API, we don't want to block here.
-            return None, None
+            return None, None, None
 
         launch_dummy_health_check_server(server_args.host, server_args.port)
 
@@ -730,7 +738,7 @@
             logger.error(
                 f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
             )
-        return None, None
+        return None, None, None
 
     # Launch detokenizer process
     detoken_proc = mp.Process(
@@ -744,15 +752,15 @@
 
     # Launch tokenizer process
    tokenizer_manager = TokenizerManager(server_args, port_args)
-    if server_args.chat_template:
-        load_chat_template_for_openai_api(
-            tokenizer_manager, server_args.chat_template, server_args.model_path
-        )
-    else:
-        guess_chat_template_name_from_model_path(server_args.model_path)
 
-    if server_args.completion_template:
-        load_completion_template_for_openai_api(server_args.completion_template)
+    # Initialize templates
+    template_manager = TemplateManager()
+    template_manager.initialize_templates(
+        tokenizer_manager=tokenizer_manager,
+        model_path=server_args.model_path,
+        chat_template=server_args.chat_template,
+        completion_template=server_args.completion_template,
+    )
 
     # Wait for the model to finish loading
@@ -776,4 +784,4 @@
     # Assume all schedulers have the same scheduler_info
     scheduler_info = scheduler_infos[0]
     tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
-    return tokenizer_manager, scheduler_info
+    return tokenizer_manager, template_manager, scheduler_info
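_launch_subprocesses now returns a 3-tuple that includes the new TemplateManager, which is why the early-exit paths above return None, None, None; every call site has to unpack three values. A minimal sketch of the new call-site shape; constructing ServerArgs with just a model path is an illustrative simplification:

    from sglang.srt.entrypoints.engine import _launch_subprocesses
    from sglang.srt.server_args import ServerArgs

    server_args = ServerArgs(model_path="meta-llama/Llama-3.1-8B-Instruct")  # hypothetical model

    # 0.4.7: tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
    tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
        server_args=server_args
    )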