sglang 0.4.5.post2__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. sglang/bench_serving.py +3 -2
  2. sglang/compile_deep_gemm.py +136 -0
  3. sglang/lang/backend/openai.py +5 -1
  4. sglang/lang/backend/runtime_endpoint.py +5 -1
  5. sglang/srt/configs/model_config.py +4 -1
  6. sglang/srt/constrained/xgrammar_backend.py +1 -0
  7. sglang/srt/disaggregation/decode.py +43 -0
  8. sglang/srt/disaggregation/mini_lb.py +69 -8
  9. sglang/srt/disaggregation/mooncake/conn.py +1 -1
  10. sglang/srt/disaggregation/nixl/__init__.py +1 -0
  11. sglang/srt/disaggregation/nixl/conn.py +622 -0
  12. sglang/srt/disaggregation/prefill.py +100 -16
  13. sglang/srt/disaggregation/utils.py +17 -0
  14. sglang/srt/entrypoints/engine.py +4 -0
  15. sglang/srt/entrypoints/http_server.py +3 -7
  16. sglang/srt/function_call_parser.py +60 -0
  17. sglang/srt/layers/activation.py +2 -2
  18. sglang/srt/layers/attention/flashattention_backend.py +781 -150
  19. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
  20. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
  21. sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
  22. sglang/srt/layers/dp_attention.py +1 -1
  23. sglang/srt/layers/layernorm.py +19 -4
  24. sglang/srt/layers/moe/ep_moe/layer.py +2 -0
  25. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  26. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  27. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
  28. sglang/srt/layers/quantization/deep_gemm.py +378 -0
  29. sglang/srt/layers/quantization/fp8_kernel.py +7 -38
  30. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  31. sglang/srt/layers/quantization/gptq.py +13 -7
  32. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  33. sglang/srt/layers/quantization/w8a8_int8.py +3 -3
  34. sglang/srt/layers/rotary_embedding.py +6 -6
  35. sglang/srt/layers/sampler.py +2 -2
  36. sglang/srt/managers/data_parallel_controller.py +7 -1
  37. sglang/srt/managers/io_struct.py +14 -3
  38. sglang/srt/managers/schedule_batch.py +13 -0
  39. sglang/srt/managers/scheduler.py +16 -6
  40. sglang/srt/managers/tokenizer_manager.py +115 -29
  41. sglang/srt/managers/tp_worker.py +1 -0
  42. sglang/srt/mem_cache/hiradix_cache.py +40 -32
  43. sglang/srt/mem_cache/memory_pool.py +31 -13
  44. sglang/srt/model_executor/cuda_graph_runner.py +13 -8
  45. sglang/srt/model_executor/model_runner.py +19 -4
  46. sglang/srt/models/deepseek_v2.py +9 -6
  47. sglang/srt/models/minicpm3.py +2 -2
  48. sglang/srt/models/minicpmo.py +17 -6
  49. sglang/srt/openai_api/adapter.py +71 -4
  50. sglang/srt/openai_api/protocol.py +6 -1
  51. sglang/srt/server_args.py +52 -40
  52. sglang/srt/speculative/build_eagle_tree.py +2 -2
  53. sglang/srt/speculative/eagle_utils.py +2 -2
  54. sglang/srt/speculative/eagle_worker.py +2 -7
  55. sglang/srt/utils.py +46 -5
  56. sglang/test/test_utils.py +3 -1
  57. sglang/version.py +1 -1
  58. {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +3 -3
  59. {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +62 -57
  60. {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +0 -0
  61. {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
  62. {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/prefill.py

@@ -20,6 +20,7 @@ Life cycle of a request in the prefill server
 from __future__ import annotations
 
 import logging
+from collections import deque
 from typing import TYPE_CHECKING, List, Optional
 
 import torch
@@ -204,6 +205,40 @@ class SchedulerDisaggregationPrefillMixin:
             # Otherwise, it hangs under high concurrency
             self.running_batch.batch_is_full = False
 
+    @torch.no_grad()
+    def event_loop_overlap_disagg_prefill(self):
+        self.result_queue = deque()
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            self.waiting_queue.extend(
+                self.disagg_prefill_pending_queue.pop_bootstrapped()
+            )
+            self.process_prefill_chunk()
+            batch = self.get_new_batch_prefill()
+            self.cur_batch = batch
+
+            if batch:
+                result = self.run_batch(batch)
+                self.result_queue.append((batch.copy(), result))
+
+            if self.last_batch:
+                tmp_batch, tmp_result = self.result_queue.popleft()
+                self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
+
+            if len(self.disagg_prefill_inflight_queue) > 0:
+                self.process_disagg_prefill_inflight_queue()
+
+            if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+            # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
+            # Otherwise, it hangs under high concurrency
+            self.running_batch.batch_is_full = False
+
     def process_batch_result_disagg_prefill(
         self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
     ) -> None:
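
A note on the overlap event loop added above: each iteration launches run_batch for the current batch right away but only consumes the result of the previous iteration's batch, so CPU-side result handling overlaps with GPU execution. A minimal sketch of that one-iteration-deferred pattern (run_batch and process_result here are illustrative stand-ins, not sglang APIs):

    from collections import deque

    def run_batch(batch):
        return f"result-of-{batch}"

    def process_result(batch, result):
        print(f"processed {batch}: {result}")

    result_queue = deque()
    last_batch = None
    for batch in ["b0", "b1", "b2"]:
        result_queue.append((batch, run_batch(batch)))  # launch current batch first
        if last_batch is not None:
            process_result(*result_queue.popleft())  # consume the previous batch's result
        last_batch = batch
    while result_queue:  # drain the final deferred result
        process_result(*result_queue.popleft())
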
@@ -212,7 +247,26 @@ class SchedulerDisaggregationPrefillMixin:
         Adapted from process_batch_result_prefill
         """
 
-        next_token_ids = result.next_token_ids.tolist()
+        (
+            logits_output,
+            next_token_ids,
+            extend_input_len_per_req,
+            extend_logprob_start_len_per_req,
+            bid,
+        ) = (
+            result.logits_output,
+            result.next_token_ids,
+            result.extend_input_len_per_req,
+            result.extend_logprob_start_len_per_req,
+            result.bid,
+        )
+
+        # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
+        if self.enable_overlap:
+            # wait
+            _, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+        else:
+            next_token_ids = result.next_token_ids.tolist()
 
         for req, next_token_id in zip(batch.reqs, next_token_ids, strict=True):
             req: Req
@@ -226,12 +280,8 @@ class SchedulerDisaggregationPrefillMixin:
                 # being chunked reqs' prefill is not finished
                 req.is_chunked -= 1
 
-        # TODO: Not sure if this is necessary
-        if batch.next_batch_sampling_info:
-            batch.next_batch_sampling_info.update_regex_vocab_mask()
-            # We need to remove this for overlap schedule.
-            self.current_stream.synchronize()
-            batch.next_batch_sampling_info.sampling_info_done.set()
+                if self.enable_overlap:
+                    self.send_kv_chunk(req, end_idx=req.tmp_end_idx)
 
     def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:
         """
@@ -276,34 +326,68 @@ class SchedulerDisaggregationPrefillMixin:
                 # only finished requests to running_batch.
                 self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
                 self.tree_cache.cache_unfinished_req(self.chunked_req)
-                self.send_kv_chunk(self.chunked_req)
+                if (
+                    self.enable_overlap
+                ):  # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
+                    self.chunked_req.tmp_end_idx = min(
+                        len(self.chunked_req.fill_ids),
+                        len(self.chunked_req.origin_input_ids),
+                    )
+                else:
+                    self.send_kv_chunk(self.chunked_req)
                 # chunked request keeps its rid but will get a new req_pool_idx
                 self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
                 self.running_batch.batch_is_full = False
 
     def send_kv_chunk(
-        self: Scheduler, req: Req, token_id: Optional[int] = None
+        self: Scheduler,
+        req: Req,
+        token_id: Optional[int] = None,
+        end_idx: Optional[int] = None,
     ) -> None:
         """
         Send a prefilled chunk to the decode server
         """
+        page_size = self.token_to_kv_pool_allocator.page_size
         start_idx = req.start_send_idx
-        end_idx = min(len(req.fill_ids), len(req.origin_input_ids))
+        # if end_idx is specified, use it as the end index of the kv chunk because in overlap schedule,
+        # the resolved length is not the same as fill_ids's length
+        end_idx = (
+            end_idx
+            if end_idx is not None
+            else min(len(req.fill_ids), len(req.origin_input_ids))
+        )
+        last_chunk = token_id is not None
+
+        if (not last_chunk) and (
+            end_idx % page_size != 0
+        ):  # todo: remove the second condition
+            # if not the last chunk and the last page is partial, delay the last partial page to the next send
+            end_idx = end_idx - end_idx % page_size
 
         # Update next start_send_idx
         req.start_send_idx = end_idx
 
         kv_indices = (
-            self.req_to_token_pool.req_to_token[req.req_pool_idx][start_idx:end_idx]
+            self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
             .cpu()
             .numpy()
         )
-        if token_id is not None:
+        if last_chunk is True:
             self.disagg_prefill_pending_queue.store_prefill_results(
                 req.metadata_buffer_index, token_id
             )
-        is_last = token_id is not None
-        page_indices = kv_to_page_indices(
-            kv_indices, self.token_to_kv_pool_allocator.page_size
+        page_indices = kv_to_page_indices(kv_indices, page_size)
+
+        page_start_idx = start_idx // page_size
+        page_end_idx = page_start_idx + len(page_indices)
+
+        if len(page_indices) == 0:
+            logger.info(
+                f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
+            )
+            return
+
+        req.disagg_kv_sender.send(
+            page_indices, slice(page_start_idx, page_end_idx), last_chunk
         )
-        req.disagg_kv_sender.send(page_indices, slice(start_idx, end_idx), is_last)
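
The page-alignment logic in send_kv_chunk is easier to see with concrete numbers: a non-final chunk holds back any trailing partial page, and the slice passed to disagg_kv_sender.send is now expressed in page units rather than token units. A small sketch with assumed values (num_pages below plays the role of len(page_indices) in the real method):

    page_size = 32
    start_idx, end_idx = 64, 150  # tokens already sent vs. tokens prefilled so far
    last_chunk = False

    if (not last_chunk) and end_idx % page_size != 0:
        end_idx -= end_idx % page_size  # hold back the partial page: 150 -> 128

    num_pages = (end_idx - start_idx) // page_size  # 2 full pages to send now
    page_start_idx = start_idx // page_size         # 2
    page_end_idx = page_start_idx + num_pages       # 4
    print(end_idx, page_start_idx, page_end_idx)    # 128 2 4
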
sglang/srt/disaggregation/utils.py

@@ -47,6 +47,7 @@ class ReqToMetadataIdxAllocator:
 
 class TransferBackend(Enum):
     MOONCAKE = "mooncake"
+    NIXL = "nixl"
     FAKE = "fake"
 
 
@@ -73,6 +74,21 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
             KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
         }
         return class_mapping.get(class_type)
+    if transfer_backend == TransferBackend.NIXL:
+        from sglang.srt.disaggregation.nixl import (
+            NixlKVBootstrapServer,
+            NixlKVManager,
+            NixlKVReceiver,
+            NixlKVSender,
+        )
+
+        class_mapping = {
+            KVClassType.MANAGER: NixlKVManager,
+            KVClassType.SENDER: NixlKVSender,
+            KVClassType.RECEIVER: NixlKVReceiver,
+            KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
+        }
+        return class_mapping.get(class_type)
     raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
 
 
@@ -82,6 +98,7 @@ def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
     # The return vector is kv_indices[::page_size] // page_size
     if page_size == 1:  # shortcut
         return kv_indices
+
    return kv_indices[::page_size] // page_size
 
 
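
For reference, kv_to_page_indices relies on pages being contiguous blocks of page_size token slots, so taking every page_size-th token index and dividing by page_size yields the page numbers. A quick check with assumed values:

    import numpy as np

    # Two pages of size 4 allocated at token slots 64..71.
    kv_indices = np.array([64, 65, 66, 67, 68, 69, 70, 71])
    page_size = 4
    print(kv_indices[::page_size] // page_size)  # [16 17]
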
sglang/srt/entrypoints/engine.py

@@ -279,6 +279,10 @@ class Engine(EngineBase):
         self.shutdown()
         return False
 
+    def flush_cache(self):
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(self.tokenizer_manager.flush_cache())
+
     def start_profile(self):
         loop = asyncio.get_event_loop()
         loop.run_until_complete(self.tokenizer_manager.start_profile())
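
The new Engine.flush_cache is a synchronous wrapper over TokenizerManager.flush_cache. A hypothetical offline usage sketch (model path is a placeholder; judging from the http_server change below, the returned object reports success only when no requests are running or queued):

    import sglang as sgl

    engine = sgl.Engine(model_path="/path/to/model")  # placeholder path
    engine.generate("The capital of France is", {"max_new_tokens": 8})
    ret = engine.flush_cache()  # flushes the radix cache; refused while requests are in flight
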
sglang/srt/entrypoints/http_server.py

@@ -25,11 +25,8 @@ import multiprocessing as multiprocessing
 import os
 import threading
 import time
-from ast import Mult
 from http import HTTPStatus
-from typing import AsyncIterator, Callable, Dict, Optional, Union
-
-from sglang.srt.model_executor.model_runner import LocalSerializedTensor
+from typing import AsyncIterator, Callable, Dict, Optional
 
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -84,7 +81,6 @@ from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
-    MultiprocessingSerializer,
     add_api_key_middleware,
     add_prometheus_middleware,
     delete_directory,
@@ -315,11 +311,11 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
 @app.api_route("/flush_cache", methods=["GET", "POST"])
 async def flush_cache():
     """Flush the radix cache."""
-    _global_state.tokenizer_manager.flush_cache()
+    ret = await _global_state.tokenizer_manager.flush_cache()
     return Response(
         content="Cache flushed.\nPlease check backend logs for more details. "
         "(When there are running or waiting requests, the operation will not be performed.)\n",
-        status_code=200,
+        status_code=200 if ret.success else HTTPStatus.BAD_REQUEST,
     )
 
 
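
With this change the endpoint's status code reflects whether the flush actually happened. A hypothetical client-side check (host and port are assumptions):

    import requests

    resp = requests.post("http://localhost:30000/flush_cache")
    # 200: cache flushed; 400: refused because requests were running or waiting
    print(resp.status_code, resp.text)
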
sglang/srt/function_call_parser.py

@@ -25,6 +25,7 @@ TOOLS_TAG_LIST = [
     "<tool_call>",
     "<|python_tag|>",
     "[TOOL_CALLS]",
+    "<|tool▁calls▁begin|>",
 ]
 
 
@@ -477,6 +478,64 @@ class Llama32Detector(BaseFormatDetector):
         )
 
 
+class DeepSeekV3Detector(BaseFormatDetector):
+    """
+    Detector for DeepSeek models.
+    Assumes function call format:
+    '<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Tokyo"}\n```<|tool▁call▁end|>\n<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Paris"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>'
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.bot_token = "<|tool▁calls▁begin|>"
+        self.eot_token = "<|tool▁calls▁end|>"
+        self.func_call_regex = r"<|tool▁call▁begin|>.*?<|tool▁call▁end|>"
+        self.func_detail_regex = r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)\n```<|tool▁call▁end|>"
+
+    def has_tool_call(self, text: str) -> bool:
+        """Check if the text contains a deepseek format tool call."""
+        return self.bot_token in text
+
+    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+        """
+        One-time parsing: Detects and parses tool calls in the provided text.
+
+        :param text: The complete text to parse.
+        :param tools: List of available tools.
+        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
+        """
+        idx = text.find(self.bot_token)
+        normal_text = text[:idx].strip() if idx != -1 else text
+        if self.bot_token not in text:
+            return StreamingParseResult(normal_text=normal_text, calls=[])
+        match_result_list = re.findall(self.func_call_regex, text, re.DOTALL)
+        calls = []
+        try:
+            for match_result in match_result_list:
+                # Get function name
+                func_detail = re.search(self.func_detail_regex, match_result, re.DOTALL)
+                func_name = func_detail.group(2)
+                func_args = func_detail.group(3)
+                func_args = json.loads(func_args)
+                # construct match_result for parse_base_json
+                match_result = {"name": func_name, "parameters": func_args}
+                calls.extend(self.parse_base_json(match_result, tools))
+            return StreamingParseResult(normal_text=normal_text, calls=calls)
+        except Exception as e:
+            logger.error(f"Error in detect_and_parse: {e}")
+            # return the normal text if parsing fails
+            return StreamingParseResult(normal_text=text)
+
+    def structure_info(self) -> _GetInfoFunc:
+        return lambda name: StructureInfo(
+            begin="<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>"
+            + name
+            + "\n```json\n",
+            end="\n```<|tool▁call▁end|><|tool▁calls▁end|>",
+            trigger="<|tool▁calls▁begin|>",
+        )
+
+
 class MultiFormatParser:
     def __init__(self, detectors: List[BaseFormatDetector]):
         """
@@ -543,6 +602,7 @@ class FunctionCallParser:
         "llama3": Llama32Detector,
         "qwen25": Qwen25Detector,
         "mistral": MistralDetector,
+        "deepseekv3": DeepSeekV3Detector,
     }
 
     def __init__(self, tools: List[Tool], tool_call_parser: str):
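
The new registry key makes the detector selectable by name wherever FunctionCallParser is constructed, for example (assuming `tools` is a populated List[Tool]):

    parser = FunctionCallParser(tools=tools, tool_call_parser="deepseekv3")
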
sglang/srt/layers/activation.py

@@ -28,9 +28,9 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.utils import is_cuda_available, set_weight_attrs
+from sglang.srt.utils import is_cuda, set_weight_attrs
 
-_is_cuda = is_cuda_available()
+_is_cuda = is_cuda()
 
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul