sglang-0.4.7-py3-none-any.whl → sglang-0.4.7.post1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99):
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_serving.py +1 -1
  4. sglang/lang/interpreter.py +40 -1
  5. sglang/lang/ir.py +27 -0
  6. sglang/math_utils.py +8 -0
  7. sglang/srt/configs/model_config.py +6 -0
  8. sglang/srt/conversation.py +6 -0
  9. sglang/srt/disaggregation/base/__init__.py +1 -1
  10. sglang/srt/disaggregation/base/conn.py +25 -11
  11. sglang/srt/disaggregation/common/__init__.py +5 -1
  12. sglang/srt/disaggregation/common/utils.py +42 -0
  13. sglang/srt/disaggregation/decode.py +196 -51
  14. sglang/srt/disaggregation/fake/__init__.py +1 -1
  15. sglang/srt/disaggregation/fake/conn.py +15 -9
  16. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  17. sglang/srt/disaggregation/mooncake/conn.py +18 -13
  18. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  19. sglang/srt/disaggregation/nixl/conn.py +17 -12
  20. sglang/srt/disaggregation/prefill.py +128 -43
  21. sglang/srt/disaggregation/utils.py +127 -123
  22. sglang/srt/entrypoints/engine.py +15 -1
  23. sglang/srt/entrypoints/http_server.py +13 -2
  24. sglang/srt/eplb_simulator/__init__.py +1 -0
  25. sglang/srt/eplb_simulator/reader.py +51 -0
  26. sglang/srt/layers/activation.py +19 -0
  27. sglang/srt/layers/attention/aiter_backend.py +15 -2
  28. sglang/srt/layers/attention/cutlass_mla_backend.py +38 -15
  29. sglang/srt/layers/attention/flashattention_backend.py +53 -64
  30. sglang/srt/layers/attention/flashinfer_backend.py +1 -2
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +22 -24
  32. sglang/srt/layers/attention/flashmla_backend.py +2 -10
  33. sglang/srt/layers/attention/triton_backend.py +119 -119
  34. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  35. sglang/srt/layers/attention/vision.py +51 -24
  36. sglang/srt/layers/communicator.py +23 -5
  37. sglang/srt/layers/linear.py +0 -4
  38. sglang/srt/layers/logits_processor.py +0 -12
  39. sglang/srt/layers/moe/ep_moe/kernels.py +6 -5
  40. sglang/srt/layers/moe/ep_moe/layer.py +42 -32
  41. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  42. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -4
  43. sglang/srt/layers/moe/topk.py +16 -8
  44. sglang/srt/layers/pooler.py +56 -0
  45. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  46. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  47. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  48. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  49. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  50. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  51. sglang/srt/layers/radix_attention.py +2 -3
  52. sglang/srt/lora/lora_manager.py +79 -34
  53. sglang/srt/lora/mem_pool.py +4 -5
  54. sglang/srt/managers/cache_controller.py +2 -1
  55. sglang/srt/managers/io_struct.py +28 -4
  56. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  57. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  58. sglang/srt/managers/schedule_batch.py +39 -6
  59. sglang/srt/managers/scheduler.py +73 -17
  60. sglang/srt/managers/tokenizer_manager.py +29 -2
  61. sglang/srt/mem_cache/chunk_cache.py +1 -0
  62. sglang/srt/mem_cache/hiradix_cache.py +4 -2
  63. sglang/srt/mem_cache/memory_pool.py +111 -407
  64. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  65. sglang/srt/mem_cache/radix_cache.py +36 -12
  66. sglang/srt/model_executor/cuda_graph_runner.py +122 -55
  67. sglang/srt/model_executor/forward_batch_info.py +14 -5
  68. sglang/srt/model_executor/model_runner.py +6 -6
  69. sglang/srt/model_loader/loader.py +8 -1
  70. sglang/srt/models/bert.py +113 -13
  71. sglang/srt/models/deepseek_v2.py +113 -155
  72. sglang/srt/models/internvl.py +46 -102
  73. sglang/srt/models/roberta.py +117 -9
  74. sglang/srt/models/vila.py +305 -0
  75. sglang/srt/openai_api/adapter.py +162 -4
  76. sglang/srt/openai_api/protocol.py +37 -1
  77. sglang/srt/sampling/sampling_batch_info.py +24 -0
  78. sglang/srt/sampling/sampling_params.py +2 -0
  79. sglang/srt/server_args.py +318 -233
  80. sglang/srt/speculative/build_eagle_tree.py +1 -1
  81. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -3
  82. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +5 -2
  83. sglang/srt/speculative/eagle_utils.py +389 -109
  84. sglang/srt/speculative/eagle_worker.py +134 -43
  85. sglang/srt/two_batch_overlap.py +4 -2
  86. sglang/srt/utils.py +58 -0
  87. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  88. sglang/test/runners.py +38 -3
  89. sglang/test/test_block_fp8.py +1 -0
  90. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  91. sglang/test/test_block_fp8_ep.py +1 -0
  92. sglang/test/test_utils.py +3 -1
  93. sglang/utils.py +9 -0
  94. sglang/version.py +1 -1
  95. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +5 -5
  96. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +99 -88
  97. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
@@ -14,15 +14,15 @@ import requests
14
14
  import torch
15
15
  import torch.distributed as dist
16
16
 
17
- from sglang.srt.utils import get_ip, get_local_ip_by_remote
17
+ from sglang.srt.utils import get_ip
18
18
 
19
19
  if TYPE_CHECKING:
20
20
  from sglang.srt.managers.schedule_batch import Req
21
21
 
22
- FakeBootstrapHost = "2.2.2.2"
23
-
24
- # env var for testing failure, convert to float explicitly
25
- FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0))
22
+ #########################
23
+ # Constants & Enums
24
+ #########################
25
+ FAKE_BOOTSTRAP_HOST = "2.2.2.2"
26
26
 
27
27
 
28
28
  class DisaggregationMode(Enum):
@@ -31,6 +31,14 @@ class DisaggregationMode(Enum):
31
31
  DECODE = "decode"
32
32
 
33
33
 
34
+ #########################
35
+ # Synchronization
36
+ #########################
37
+
38
+ # env var for testing failure, convert to float explicitly
39
+ FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0))
40
+
41
+
34
42
  def poll_and_all_reduce(pollers, gloo_group):
35
43
  # at a certain prob, the poll is failed to simulate failure
36
44
  if FAILURE_PROB > 0:
@@ -47,6 +55,11 @@ def poll_and_all_reduce(pollers, gloo_group):
47
55
  return tensor_to_reduce.tolist()
48
56
 
49
57
 
58
+ #########################
59
+ # Metadata Buffers
60
+ #########################
61
+
62
+
50
63
  class ReqToMetadataIdxAllocator:
51
64
  """A memory pool that maps a request to its first output token location."""
52
65
 
@@ -70,6 +83,91 @@ class ReqToMetadataIdxAllocator:
70
83
  self.free_slots.append(free_index)
71
84
 
72
85
 
86
+ class MetadataBuffers:
87
+ def __init__(self, size: int, max_top_logprobs_num: int = 128):
88
+ # TODO: abort top_logprobs_num > 128 in PD
89
+
90
+ # We transfer the metadata of first output token to decode
91
+ # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
92
+ self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")
93
+ self.output_token_logprobs_val = torch.zeros(
94
+ (size, 16), dtype=torch.float32, device="cpu"
95
+ )
96
+ self.output_token_logprobs_idx = torch.zeros(
97
+ (size, 16), dtype=torch.int32, device="cpu"
98
+ )
99
+ self.output_top_logprobs_val = torch.zeros(
100
+ (size, max_top_logprobs_num), dtype=torch.float32, device="cpu"
101
+ )
102
+ self.output_top_logprobs_idx = torch.zeros(
103
+ (size, max_top_logprobs_num), dtype=torch.int32, device="cpu"
104
+ )
105
+
106
+ def get_buf_infos(self):
107
+ ptrs = [
108
+ self.output_ids.data_ptr(),
109
+ self.output_token_logprobs_val.data_ptr(),
110
+ self.output_token_logprobs_idx.data_ptr(),
111
+ self.output_top_logprobs_val.data_ptr(),
112
+ self.output_top_logprobs_idx.data_ptr(),
113
+ ]
114
+ data_lens = [
115
+ self.output_ids.nbytes,
116
+ self.output_token_logprobs_val.nbytes,
117
+ self.output_token_logprobs_idx.nbytes,
118
+ self.output_top_logprobs_val.nbytes,
119
+ self.output_top_logprobs_idx.nbytes,
120
+ ]
121
+ item_lens = [
122
+ self.output_ids[0].nbytes,
123
+ self.output_token_logprobs_val[0].nbytes,
124
+ self.output_token_logprobs_idx[0].nbytes,
125
+ self.output_top_logprobs_val[0].nbytes,
126
+ self.output_top_logprobs_idx[0].nbytes,
127
+ ]
128
+ return ptrs, data_lens, item_lens
129
+
130
+ def get_buf(self, idx: int):
131
+ return (
132
+ self.output_ids[idx],
133
+ self.output_token_logprobs_val[idx],
134
+ self.output_token_logprobs_idx[idx],
135
+ self.output_top_logprobs_val[idx],
136
+ self.output_top_logprobs_idx[idx],
137
+ )
138
+
139
+ def set_buf(self, req: Req):
140
+
141
+ self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
142
+ if req.return_logprob:
143
+ if req.output_token_logprobs_val: # not none or empty list
144
+ self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
145
+ req.output_token_logprobs_val[0]
146
+ )
147
+ if req.output_token_logprobs_idx: # not none or empty list
148
+ self.output_token_logprobs_idx[req.metadata_buffer_index][0] = (
149
+ req.output_token_logprobs_idx[0]
150
+ )
151
+
152
+ if req.output_top_logprobs_val: # not none or empty list
153
+ self.output_top_logprobs_val[req.metadata_buffer_index][
154
+ : len(req.output_top_logprobs_val[0])
155
+ ] = torch.tensor(
156
+ req.output_top_logprobs_val[0], dtype=torch.float32, device="cpu"
157
+ )
158
+ if req.output_top_logprobs_idx: # not none or empty list
159
+ self.output_top_logprobs_idx[req.metadata_buffer_index][
160
+ : len(req.output_top_logprobs_idx[0])
161
+ ] = torch.tensor(
162
+ req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
163
+ )
164
+
165
+
166
+ #########################
167
+ # Transfer Backend
168
+ #########################
169
+
170
+
73
171
  class TransferBackend(Enum):
74
172
  MOONCAKE = "mooncake"
75
173
  NIXL = "nixl"
@@ -77,6 +175,7 @@ class TransferBackend(Enum):
77
175
 
78
176
 
79
177
  class KVClassType(Enum):
178
+ KVARGS = "kvargs"
80
179
  MANAGER = "manager"
81
180
  SENDER = "sender"
82
181
  RECEIVER = "receiver"
@@ -87,6 +186,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
87
186
  from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
88
187
 
89
188
  if transfer_backend == TransferBackend.MOONCAKE:
189
+ from sglang.srt.disaggregation.base import KVArgs
90
190
  from sglang.srt.disaggregation.mooncake import (
91
191
  MooncakeKVBootstrapServer,
92
192
  MooncakeKVManager,
@@ -95,13 +195,15 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
95
195
  )
96
196
 
97
197
  class_mapping = {
198
+ KVClassType.KVARGS: KVArgs,
98
199
  KVClassType.MANAGER: MooncakeKVManager,
99
200
  KVClassType.SENDER: MooncakeKVSender,
100
201
  KVClassType.RECEIVER: (MooncakeKVReceiver),
101
202
  KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
102
203
  }
103
204
  return class_mapping.get(class_type)
104
- if transfer_backend == TransferBackend.NIXL:
205
+ elif transfer_backend == TransferBackend.NIXL:
206
+ from sglang.srt.disaggregation.base import KVArgs
105
207
  from sglang.srt.disaggregation.nixl import (
106
208
  NixlKVBootstrapServer,
107
209
  NixlKVManager,
@@ -110,16 +212,19 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
110
212
  )
111
213
 
112
214
  class_mapping = {
215
+ KVClassType.KVARGS: KVArgs,
113
216
  KVClassType.MANAGER: NixlKVManager,
114
217
  KVClassType.SENDER: NixlKVSender,
115
218
  KVClassType.RECEIVER: (NixlKVReceiver),
116
219
  KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
117
220
  }
118
221
  return class_mapping.get(class_type)
119
- if transfer_backend == TransferBackend.FAKE:
222
+ elif transfer_backend == TransferBackend.FAKE:
223
+ from sglang.srt.disaggregation.base import KVArgs
120
224
  from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
121
225
 
122
226
  class_mapping = {
227
+ KVClassType.KVARGS: KVArgs,
123
228
  KVClassType.SENDER: FakeKVSender,
124
229
  KVClassType.RECEIVER: (FakeKVReceiver),
125
230
  }
@@ -128,6 +233,11 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
128
233
  raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
129
234
 
130
235
 
236
+ #########################
237
+ # KV Pages
238
+ #########################
239
+
240
+
131
241
  def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
132
242
  # 1. The page is guaranteed to be full except the last page.
133
243
  # 2. page index = kv_index // page_size
@@ -143,6 +253,11 @@ def kv_to_page_num(num_kv_indices: int, page_size: int):
143
253
  return (num_kv_indices + page_size - 1) // page_size
144
254
 
145
255
 
256
+ #########################
257
+ # PDLB Registry
258
+ #########################
259
+
260
+
146
261
  @dataclasses.dataclass
147
262
  class PDRegistryRequest:
148
263
  """A request to register a machine itself to the LB."""
@@ -181,6 +296,11 @@ def register_disaggregation_server(
181
296
  )
182
297
 
183
298
 
299
+ #########################
300
+ # Misc
301
+ #########################
302
+
303
+
184
304
  def is_mla_backend(target_kv_pool) -> bool:
185
305
  from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
186
306
 
@@ -200,119 +320,3 @@ def prepare_abort(req: Req, error_message: str, status_code=None):
200
320
  req.input_top_logprobs_idx = []
201
321
  req.input_token_ids_logprobs_val = []
202
322
  req.input_token_ids_logprobs_idx = []
203
-
204
-
205
- class MetadataBuffers:
206
- def __init__(self, size: int, max_top_logprobs_num: int = 128):
207
- # TODO: abort top_logprobs_num > 128 in PD
208
-
209
- # We transfer the metadata of first output token to decode
210
- # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
211
- self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")
212
- self.output_token_logprobs_val = torch.zeros(
213
- (size, 16), dtype=torch.float32, device="cpu"
214
- )
215
- self.output_token_logprobs_idx = torch.zeros(
216
- (size, 16), dtype=torch.int32, device="cpu"
217
- )
218
- self.output_top_logprobs_val = torch.zeros(
219
- (size, max_top_logprobs_num), dtype=torch.float32, device="cpu"
220
- )
221
- self.output_top_logprobs_idx = torch.zeros(
222
- (size, max_top_logprobs_num), dtype=torch.int32, device="cpu"
223
- )
224
-
225
- def get_buf_infos(self):
226
- ptrs = [
227
- self.output_ids.data_ptr(),
228
- self.output_token_logprobs_val.data_ptr(),
229
- self.output_token_logprobs_idx.data_ptr(),
230
- self.output_top_logprobs_val.data_ptr(),
231
- self.output_top_logprobs_idx.data_ptr(),
232
- ]
233
- data_lens = [
234
- self.output_ids.nbytes,
235
- self.output_token_logprobs_val.nbytes,
236
- self.output_token_logprobs_idx.nbytes,
237
- self.output_top_logprobs_val.nbytes,
238
- self.output_top_logprobs_idx.nbytes,
239
- ]
240
- item_lens = [
241
- self.output_ids[0].nbytes,
242
- self.output_token_logprobs_val[0].nbytes,
243
- self.output_token_logprobs_idx[0].nbytes,
244
- self.output_top_logprobs_val[0].nbytes,
245
- self.output_top_logprobs_idx[0].nbytes,
246
- ]
247
- return ptrs, data_lens, item_lens
248
-
249
- def get_buf(self, idx: int):
250
- return (
251
- self.output_ids[idx],
252
- self.output_token_logprobs_val[idx],
253
- self.output_token_logprobs_idx[idx],
254
- self.output_top_logprobs_val[idx],
255
- self.output_top_logprobs_idx[idx],
256
- )
257
-
258
- def set_buf(self, req: Req):
259
-
260
- self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
261
- if req.return_logprob:
262
- if req.output_token_logprobs_val: # not none or empty list
263
- self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
264
- req.output_token_logprobs_val[0]
265
- )
266
- if req.output_token_logprobs_idx: # not none or empty list
267
- self.output_token_logprobs_idx[req.metadata_buffer_index][0] = (
268
- req.output_token_logprobs_idx[0]
269
- )
270
-
271
- if req.output_top_logprobs_val: # not none or empty list
272
- self.output_top_logprobs_val[req.metadata_buffer_index][
273
- : len(req.output_top_logprobs_val[0])
274
- ] = torch.tensor(
275
- req.output_top_logprobs_val[0], dtype=torch.float32, device="cpu"
276
- )
277
- if req.output_top_logprobs_idx: # not none or empty list
278
- self.output_top_logprobs_idx[req.metadata_buffer_index][
279
- : len(req.output_top_logprobs_idx[0])
280
- ] = torch.tensor(
281
- req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
282
- )
283
-
284
-
285
- class FastQueue:
286
- def __init__(self):
287
- self._buf = deque()
288
- self._cond = threading.Condition()
289
-
290
- def put(self, item):
291
- with self._cond:
292
- self._buf.append(item)
293
- # wake up a thread of wait()
294
- self._cond.notify()
295
-
296
- def get(self):
297
- with self._cond:
298
- # if queue is empty ,block until is notified()
299
- while not self._buf:
300
- self._cond.wait()
301
- return self._buf.popleft()
302
-
303
-
304
- def group_concurrent_contiguous(
305
- src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
306
- ) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
307
- """Vectorised NumPy implementation."""
308
- if src_indices.size == 0:
309
- return [], []
310
-
311
- brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
312
- src_groups = np.split(src_indices, brk)
313
- dst_groups = np.split(dst_indices, brk)
314
-
315
- src_groups = [g.tolist() for g in src_groups]
316
- dst_groups = [g.tolist() for g in dst_groups]
317
-
318
- return src_groups, dst_groups
@@ -327,6 +327,20 @@ class Engine(EngineBase):
327
327
  generator = self.tokenizer_manager.generate_request(obj, None)
328
328
  return await generator.__anext__()
329
329
 
330
+ def rerank(
331
+ self,
332
+ prompt: Union[List[List[str]]],
333
+ ) -> Dict:
334
+ """
335
+ The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
336
+ Please refer to `EmbeddingReqInput` for the documentation.
337
+ """
338
+ obj = EmbeddingReqInput(text=prompt, is_cross_encoder_request=True)
339
+ loop = asyncio.get_event_loop()
340
+ generator = self.tokenizer_manager.generate_request(obj, None)
341
+ ret = loop.run_until_complete(generator.__anext__())
342
+ return ret
343
+
330
344
  def shutdown(self):
331
345
  """Shutdown the engine"""
332
346
  kill_process_tree(os.getpid(), include_parent=False)
@@ -605,7 +619,7 @@ def _set_envs_and_config(server_args: ServerArgs):
605
619
  if _is_cuda:
606
620
  assert_pkg_version(
607
621
  "sgl-kernel",
608
- "0.1.7",
622
+ "0.1.9",
609
623
  "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
610
624
  )
611
625
 
@@ -43,7 +43,7 @@ from fastapi.middleware.cors import CORSMiddleware
43
43
  from fastapi.responses import ORJSONResponse, Response, StreamingResponse
44
44
 
45
45
  from sglang.srt.disaggregation.utils import (
46
- FakeBootstrapHost,
46
+ FAKE_BOOTSTRAP_HOST,
47
47
  register_disaggregation_server,
48
48
  )
49
49
  from sglang.srt.entrypoints.engine import _launch_subprocesses
@@ -67,6 +67,7 @@ from sglang.srt.managers.io_struct import (
67
67
  UpdateWeightFromDiskReqInput,
68
68
  UpdateWeightsFromDistributedReqInput,
69
69
  UpdateWeightsFromTensorReqInput,
70
+ V1RerankReqInput,
70
71
  VertexGenerateReqInput,
71
72
  )
72
73
  from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -79,6 +80,7 @@ from sglang.srt.openai_api.adapter import (
79
80
  v1_delete_file,
80
81
  v1_embeddings,
81
82
  v1_files_create,
83
+ v1_rerank,
82
84
  v1_retrieve_batch,
83
85
  v1_retrieve_file,
84
86
  v1_retrieve_file_content,
@@ -328,6 +330,15 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
328
330
  return _create_error_response(e)
329
331
 
330
332
 
333
+ @app.api_route("/v1/rerank", methods=["POST", "PUT"])
334
+ async def v1_rerank_request(obj: V1RerankReqInput, raw_request: Request):
335
+ try:
336
+ ret = await v1_rerank(_global_state.tokenizer_manager, obj, raw_request)
337
+ return ret
338
+ except ValueError as e:
339
+ return _create_error_response(e)
340
+
341
+
331
342
  @app.api_route("/flush_cache", methods=["GET", "POST"])
332
343
  async def flush_cache():
333
344
  """Flush the radix cache."""
@@ -878,7 +889,7 @@ def _wait_and_warmup(
878
889
  "max_new_tokens": 8,
879
890
  "ignore_eos": True,
880
891
  },
881
- "bootstrap_host": [FakeBootstrapHost] * server_args.dp_size,
892
+ "bootstrap_host": [FAKE_BOOTSTRAP_HOST] * server_args.dp_size,
882
893
  # This is a hack to ensure fake transfer is enabled during prefill warmup
883
894
  # ensure each dp rank has a unique bootstrap_room during prefill warmup
884
895
  "bootstrap_room": [
@@ -0,0 +1 @@
1
+ from . import reader
@@ -0,0 +1,51 @@
1
+ from collections import defaultdict
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ from tqdm import tqdm
6
+
7
+ from sglang.srt.managers.expert_distribution import (
8
+ _convert_global_physical_count_to_logical_count,
9
+ )
10
+
11
+ convert_global_physical_count_to_logical_count = (
12
+ _convert_global_physical_count_to_logical_count
13
+ )
14
+
15
+
16
+ def read_mode_per_pass(dir_data: Path):
17
+ """Read data from ExpertDistributionRecorder when recorded with mode `per_pass`"""
18
+
19
+ # gpc := global_physical_count
20
+ gpc_of_forward_pass_and_rank = defaultdict(lambda: defaultdict())
21
+ for path in tqdm(list(dir_data.glob("*.pt"))):
22
+ data_pack = torch.load(path, weights_only=True)
23
+ last_physical_to_logical_map = data_pack["last_physical_to_logical_map"]
24
+ for record in data_pack["records"]:
25
+ forward_pass_id = record["forward_pass_id"]
26
+ rank = record["rank"]
27
+ assert (
28
+ gpc_of_forward_pass_and_rank[forward_pass_id].get(rank) is None
29
+ ), f"Duplicated {forward_pass_id=} {rank=}"
30
+ gpc_of_forward_pass_and_rank[forward_pass_id][rank] = record[
31
+ "global_physical_count"
32
+ ]
33
+
34
+ forward_pass_ids = sorted(gpc_of_forward_pass_and_rank.keys())
35
+ print(f"Make {forward_pass_ids=} into array")
36
+
37
+ items = []
38
+ for forward_pass_id, gpc_of_rank in sorted(gpc_of_forward_pass_and_rank.items()):
39
+ gpc_of_rank_tensor = torch.stack(
40
+ [gpc for rank, gpc in sorted(gpc_of_rank.items())]
41
+ ).sum(dim=0)
42
+ items.append(gpc_of_rank_tensor)
43
+
44
+ gpc_of_forward_pass = torch.stack(items)
45
+ print(f"{gpc_of_forward_pass.shape=}")
46
+
47
+ return dict(
48
+ global_physical_count_of_forward_pass=gpc_of_forward_pass,
49
+ last_physical_to_logical_map=last_physical_to_logical_map,
50
+ forward_pass_ids=forward_pass_ids,
51
+ )
@@ -20,6 +20,7 @@ from typing import Optional
20
20
  import torch
21
21
  import torch.nn as nn
22
22
  import torch.nn.functional as F
23
+ from transformers import PretrainedConfig
23
24
 
24
25
  from sglang.srt.custom_op import CustomOp
25
26
  from sglang.srt.distributed import (
@@ -29,6 +30,7 @@ from sglang.srt.distributed import (
29
30
  )
30
31
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
31
32
  from sglang.srt.utils import is_cuda, set_weight_attrs
33
+ from sglang.utils import resolve_obj_by_qualname
32
34
 
33
35
  _is_cuda = is_cuda()
34
36
 
@@ -165,6 +167,23 @@ def get_act_fn(
165
167
  return act_fn
166
168
 
167
169
 
170
+ def get_cross_encoder_activation_function(config: PretrainedConfig):
171
+ if (
172
+ hasattr(config, "sbert_ce_default_activation_function")
173
+ and config.sbert_ce_default_activation_function is not None
174
+ ):
175
+
176
+ function_name = config.sbert_ce_default_activation_function
177
+ assert function_name.startswith("torch.nn.modules."), (
178
+ "Loading of activation functions is restricted to "
179
+ "torch.nn.modules for security reasons"
180
+ )
181
+ return resolve_obj_by_qualname(function_name)()
182
+ else:
183
+ # adapt bge-reranker
184
+ return nn.Identity()
185
+
186
+
168
187
  if not _is_cuda:
169
188
  logger.info(
170
189
  "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
@@ -717,6 +717,11 @@ class AiterIndicesUpdaterPrefill:
717
717
  self.req_to_token = model_runner.req_to_token_pool.req_to_token
718
718
  self.update = self.update_single_wrapper
719
719
 
720
+ # get the last index of the pool
721
+ self.pool_size = (
722
+ model_runner.token_to_kv_pool.size + model_runner.token_to_kv_pool.page_size
723
+ ) - 1
724
+
720
725
  self.kv_indices = None
721
726
  self.max_q_len = 0
722
727
  self.max_kv_len = 0
@@ -754,8 +759,16 @@ class AiterIndicesUpdaterPrefill:
754
759
  # Normal extend
755
760
  kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
756
761
  kv_indptr = kv_indptr[: bs + 1]
757
- kv_indices = torch.empty(
758
- paged_kernel_lens_sum + 256,
762
+
763
+ # (TODO: Kk) WA - CI test_moe_eval_accuracy_large.py
764
+ # mha_batch_prefill reads 128 data to do computatoin
765
+ # if real data is not long enough then original padding value 0 is used
766
+ # but the 0 location will be made nan (noqa) in cuda graph capture mode
767
+ # this will cause the output tensor value becomes nan
768
+ # WA is to assure that last index of pool not changed
769
+ kv_indices = torch.full(
770
+ (paged_kernel_lens_sum + 128,),
771
+ self.pool_size,
759
772
  dtype=torch.int32,
760
773
  device=req_pool_indices.device,
761
774
  )
@@ -11,8 +11,6 @@ from typing import TYPE_CHECKING, Optional, Union
11
11
  import torch
12
12
  import triton
13
13
 
14
- from sglang.global_config import global_config
15
- from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
16
14
  from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
17
15
  from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
18
16
  from sglang.srt.layers.dp_attention import get_attention_tp_size
@@ -22,7 +20,6 @@ from sglang.srt.utils import is_cuda
22
20
  if TYPE_CHECKING:
23
21
  from sglang.srt.layers.radix_attention import RadixAttention
24
22
  from sglang.srt.model_executor.model_runner import ModelRunner
25
- from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
26
23
  from sglang.srt.speculative.spec_info import SpecInfo
27
24
 
28
25
  _is_cuda = is_cuda()
@@ -108,7 +105,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
108
105
  PAGE_SIZE,
109
106
  )
110
107
  workspace_size = cutlass_mla_get_workspace_size(
111
- max_seqlen_pad * PAGE_SIZE, bs
108
+ max_seqlen_pad * PAGE_SIZE, bs, num_kv_splits=1
112
109
  )
113
110
  workspace = torch.empty(
114
111
  workspace_size, device="cuda", dtype=torch.uint8
@@ -138,7 +135,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
138
135
  cuda_graph_kv_indices = block_kv_indices
139
136
 
140
137
  workspace_size = cutlass_mla_get_workspace_size(
141
- cuda_graph_kv_indices.shape[1] * PAGE_SIZE, max_bs
138
+ cuda_graph_kv_indices.shape[1] * PAGE_SIZE, max_bs, num_kv_splits=1
142
139
  )
143
140
  self.cuda_graph_mla_workspace = torch.empty(
144
141
  workspace_size, device="cuda", dtype=torch.uint8
@@ -233,29 +230,55 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
233
230
  layer: RadixAttention,
234
231
  forward_batch: ForwardBatch,
235
232
  save_kv_cache: bool = True,
233
+ # For multi-head latent attention
234
+ q_rope: Optional[torch.Tensor] = None,
235
+ k_rope: Optional[torch.Tensor] = None,
236
236
  ):
237
237
  cache_loc = forward_batch.out_cache_loc
238
238
 
239
239
  if k is not None:
240
240
  assert v is not None
241
241
  if save_kv_cache:
242
- forward_batch.token_to_kv_pool.set_kv_buffer(
243
- layer,
244
- cache_loc,
245
- k,
246
- v,
247
- )
248
- bs = forward_batch.batch_size
249
- k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
242
+ if k_rope is not None:
243
+ forward_batch.token_to_kv_pool.set_mla_kv_buffer(
244
+ layer,
245
+ cache_loc,
246
+ k,
247
+ k_rope,
248
+ )
249
+ else:
250
+ forward_batch.token_to_kv_pool.set_kv_buffer(
251
+ layer,
252
+ cache_loc,
253
+ k,
254
+ v,
255
+ )
256
+
257
+ # Reshape inputs
258
+ if q_rope is not None:
259
+ q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
260
+ q_rope = q_rope.view(
261
+ -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
262
+ )
263
+ else:
264
+ reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
265
+ q_nope = reshaped_q[:, :, : layer.v_head_dim]
266
+ q_rope = reshaped_q[:, :, layer.v_head_dim :]
250
267
 
251
- reshape_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
268
+ q_nope = q_nope.to(self.q_data_type)
269
+ q_rope = q_rope.to(self.q_data_type)
270
+
271
+ k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
252
272
 
253
273
  o = cutlass_mla_decode(
254
- q_nope_and_q_pe=reshape_q.to(self.q_data_type),
274
+ q_nope=q_nope,
275
+ q_pe=q_rope,
255
276
  kv_c_and_k_pe_cache=k_cache.view(-1, PAGE_SIZE, self.kv_cache_dim),
256
277
  seq_lens=forward_batch.seq_lens.to(torch.int32),
257
278
  page_table=self.forward_metadata.block_kv_indices,
258
279
  workspace=self.forward_metadata.workspace,
280
+ sm_scale=layer.scaling,
281
+ num_kv_splits=1,
259
282
  )
260
283
 
261
284
  return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)