sglang 0.4.5.post2__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. sglang/bench_serving.py +3 -2
  2. sglang/compile_deep_gemm.py +136 -0
  3. sglang/lang/backend/openai.py +5 -1
  4. sglang/lang/backend/runtime_endpoint.py +5 -1
  5. sglang/srt/configs/model_config.py +4 -1
  6. sglang/srt/constrained/xgrammar_backend.py +1 -0
  7. sglang/srt/disaggregation/decode.py +43 -0
  8. sglang/srt/disaggregation/mini_lb.py +69 -8
  9. sglang/srt/disaggregation/mooncake/conn.py +1 -1
  10. sglang/srt/disaggregation/nixl/__init__.py +1 -0
  11. sglang/srt/disaggregation/nixl/conn.py +622 -0
  12. sglang/srt/disaggregation/prefill.py +100 -16
  13. sglang/srt/disaggregation/utils.py +17 -0
  14. sglang/srt/entrypoints/engine.py +4 -0
  15. sglang/srt/entrypoints/http_server.py +3 -7
  16. sglang/srt/function_call_parser.py +60 -0
  17. sglang/srt/layers/activation.py +2 -2
  18. sglang/srt/layers/attention/flashattention_backend.py +781 -150
  19. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
  20. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
  21. sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
  22. sglang/srt/layers/dp_attention.py +1 -1
  23. sglang/srt/layers/layernorm.py +19 -4
  24. sglang/srt/layers/moe/ep_moe/layer.py +2 -0
  25. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  26. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  27. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
  28. sglang/srt/layers/quantization/deep_gemm.py +378 -0
  29. sglang/srt/layers/quantization/fp8_kernel.py +7 -38
  30. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  31. sglang/srt/layers/quantization/gptq.py +13 -7
  32. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  33. sglang/srt/layers/quantization/w8a8_int8.py +3 -3
  34. sglang/srt/layers/rotary_embedding.py +6 -6
  35. sglang/srt/layers/sampler.py +2 -2
  36. sglang/srt/managers/data_parallel_controller.py +7 -1
  37. sglang/srt/managers/io_struct.py +14 -3
  38. sglang/srt/managers/schedule_batch.py +13 -0
  39. sglang/srt/managers/scheduler.py +16 -6
  40. sglang/srt/managers/tokenizer_manager.py +115 -29
  41. sglang/srt/managers/tp_worker.py +1 -0
  42. sglang/srt/mem_cache/hiradix_cache.py +40 -32
  43. sglang/srt/mem_cache/memory_pool.py +31 -13
  44. sglang/srt/model_executor/cuda_graph_runner.py +13 -8
  45. sglang/srt/model_executor/model_runner.py +19 -4
  46. sglang/srt/models/deepseek_v2.py +9 -6
  47. sglang/srt/models/minicpm3.py +2 -2
  48. sglang/srt/models/minicpmo.py +17 -6
  49. sglang/srt/openai_api/adapter.py +71 -4
  50. sglang/srt/openai_api/protocol.py +6 -1
  51. sglang/srt/server_args.py +52 -40
  52. sglang/srt/speculative/build_eagle_tree.py +2 -2
  53. sglang/srt/speculative/eagle_utils.py +2 -2
  54. sglang/srt/speculative/eagle_worker.py +2 -7
  55. sglang/srt/utils.py +46 -5
  56. sglang/test/test_utils.py +3 -1
  57. sglang/version.py +1 -1
  58. {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +3 -3
  59. {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +62 -57
  60. {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +0 -0
  61. {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
  62. {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
sglang/bench_serving.py CHANGED
@@ -690,7 +690,6 @@ def sample_random_requests(
     dataset_path: str,
     random_sample: bool = True,
 ) -> List[Tuple[str, int, int]]:
-
     input_lens = np.random.randint(
         max(int(input_len * range_ratio), 1),
         input_len + 1,
@@ -1025,7 +1024,9 @@ async def benchmark(
     warmup_outputs = await asyncio.gather(*warmup_tasks)
 
     # Check if at least one warmup request succeeded
-    if not any(output.success for output in warmup_outputs):
+    if args.warmup_requests > 0 and not any(
+        output.success for output in warmup_outputs
+    ):
         raise ValueError(
             "Warmup failed - Please make sure benchmark arguments "
             f"are correctly specified. Error: {warmup_outputs[0].error}"
sglang/compile_deep_gemm.py ADDED
@@ -0,0 +1,136 @@
+"""
+Compile DeepGEMM Kernels for a model with specify server arguments
+
+This script launches a server for capturing DeepGEMM calls and then compiles the kernels.
+It accepts server arguments (the same as launch_server.py).
+
+Usage:
+python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code
+
+"""
+
+import argparse
+import dataclasses
+import multiprocessing
+import os
+import time
+
+import requests
+
+from sglang.srt.entrypoints.http_server import launch_server
+from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.tokenizer_manager import TokenizerManager
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import kill_process_tree
+from sglang.srt.warmup import warmup
+
+multiprocessing.set_start_method("spawn", force=True)
+
+# Reduce warning
+os.environ["SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE"] = "1"
+
+
+@dataclasses.dataclass
+class CompileArgs:
+    timeout: int = 3600
+
+    @staticmethod
+    def add_cli_args(parser: argparse.ArgumentParser):
+        parser.add_argument("--timeout", type=int, default=CompileArgs.timeout)
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace):
+        # use the default value's type to cast the args into correct types.
+        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
+        return cls(
+            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
+        )
+
+
+@warmup("compile-deep-gemm")
+async def warm_up_compile(tokenizer_manager: TokenizerManager):
+    print("\nGenerate warm up request for compiling DeepGEMM...\n")
+    generate_req_input = GenerateReqInput(
+        input_ids=[0, 1, 2, 3],
+        sampling_params={
+            "temperature": 0.0,
+            "max_new_tokens": 8,
+            "ignore_eos": True,
+        },
+    )
+    await tokenizer_manager.generate_request(generate_req_input, None).__anext__()
+
+
+def launch_server_internal(server_args):
+    try:
+        launch_server(server_args)
+    except Exception as e:
+        raise e
+    finally:
+        kill_process_tree(os.getpid(), include_parent=False)
+
+
+def launch_server_process_and_send_one_request(
+    server_args: ServerArgs, compile_args: CompileArgs
+):
+    proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
+    proc.start()
+    base_url = f"http://{server_args.host}:{server_args.port}"
+    timeout = compile_args.timeout
+
+    start_time = time.time()
+    while time.time() - start_time < timeout:
+        try:
+            headers = {
+                "Content-Type": "application/json; charset=utf-8",
+            }
+            response = requests.get(f"{base_url}/v1/models", headers=headers)
+            if response.status_code == 200:
+                return proc
+        except requests.RequestException:
+            pass
+        time.sleep(10)
+    raise TimeoutError(
+        "DeepGEMM Kernels compilation timeout."
+        "\n\nFeel free and please restart the command."
+    )
+
+
+def refine_server_args(server_args: ServerArgs, compile_args: CompileArgs):
+    # Disbale cuda graph and torch compile to save time
+    server_args.disable_cuda_graph = True
+    server_args.enable_torch_compile = False
+    print(f"Disable CUDA Graph and Torch Compile to save time...")
+
+    # Set watchdog timeout to compile_args.timeout because compilation will take a long time
+    server_args.watchdog_timeout = compile_args.timeout
+    server_args.warmups = "compile-deep-gemm"
+
+
+def run_compile(server_args: ServerArgs, compile_args: CompileArgs):
+    print(
+        "Begin DeepGEMM Kernels compilation...\n"
+        "It may take a long time and timeout maybe raised "
+        "while the compilation is still in progress.\n"
+        "Just feel free to restart the command "
+        "until the compilation is fully finished.\n"
+    )
+
+    proc = launch_server_process_and_send_one_request(server_args, compile_args)
+
+    kill_process_tree(proc.pid)
+
+    print("\nDeepGEMM Kernels compilation finished successfully.")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    CompileArgs.add_cli_args(parser)
+    args = parser.parse_args()
+    server_args = ServerArgs.from_cli_args(args)
+    compile_args = CompileArgs.from_cli_args(args)
+
+    refine_server_args(server_args, compile_args)
+
+    run_compile(server_args, compile_args)
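CompileArgs.from_cli_args casts each parsed value with the type of the field's default, so the dataclass stays the single source of truth for the CLI. A standalone sketch of the same pattern (DemoArgs is an illustrative stand-in, not part of sglang):

    import argparse
    import dataclasses

    @dataclasses.dataclass
    class DemoArgs:
        timeout: int = 3600  # same shape as CompileArgs above

        @classmethod
        def from_cli_args(cls, args: argparse.Namespace):
            # Cast each value using the type of the field's default.
            fields = dataclasses.fields(cls)
            return cls(**{f.name: type(f.default)(getattr(args, f.name)) for f in fields})

    parser = argparse.ArgumentParser()
    parser.add_argument("--timeout", type=int, default=DemoArgs.timeout)
    print(DemoArgs.from_cli_args(parser.parse_args(["--timeout", "7200"])))
    # DemoArgs(timeout=7200)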
sglang/lang/backend/openai.py CHANGED
@@ -161,7 +161,11 @@ class OpenAI(BaseBackend):
             prompt = s.text_
 
         kwargs = sampling_params.to_openai_kwargs()
-        if self.model_name.startswith("o1") or self.model_name.startswith("o3"):
+        if (
+            self.model_name.startswith("o1")
+            or self.model_name.startswith("o3")
+            or "o1" in self.model_name
+        ):
             kwargs.pop("max_tokens", None)
         else:
             kwargs.pop("max_completion_tokens", None)
sglang/lang/backend/runtime_endpoint.py CHANGED
@@ -324,7 +324,11 @@ class RuntimeEndpoint(BaseBackend):
 
     def _assert_success(self, res):
         if res.status_code != 200:
-            raise RuntimeError(res.json())
+            try:
+                content = res.json()
+            except json.JSONDecodeError:
+                content = res.text
+            raise RuntimeError(content)
 
 
 def compute_normalized_prompt_logprobs(input_logprobs):
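The fallback keeps the error readable when a proxy or a crashed server returns HTML or plain text that res.json() cannot parse. A minimal sketch of the pattern, assuming requests >= 2.27 (where requests' JSONDecodeError subclasses json.JSONDecodeError):

    import json

    import requests

    def error_content(res: requests.Response):
        # Prefer the structured JSON body; fall back to the raw text.
        try:
            return res.json()
        except json.JSONDecodeError:
            return res.text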
sglang/srt/configs/model_config.py CHANGED
@@ -73,8 +73,11 @@ class ModelConfig:
         )
 
         if enable_multimodal is None:
-            if self.hf_config.architectures == "Llama4ForConditionalGeneration":
+            if self.hf_config.architectures[0] == "Llama4ForConditionalGeneration":
                 enable_multimodal = False
+                logger.info(
+                    "Multimodal is disabled for Llama4. To enable it, set --enable-llama4-multimodal."
+                )
             else:
                 enable_multimodal = True
 
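This is a real bug fix, not a refactor: hf_config.architectures is a list of class-name strings, so the old comparison against a bare string was always False and the Llama4 branch was dead code. In isolation:

    architectures = ["Llama4ForConditionalGeneration"]  # typical HF config value

    # Old check: a list never equals a string, so this was always False.
    assert (architectures == "Llama4ForConditionalGeneration") is False

    # Fixed check: inspect the first (and usually only) entry.
    assert architectures[0] == "Llama4ForConditionalGeneration"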
sglang/srt/constrained/xgrammar_backend.py CHANGED
@@ -158,6 +158,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
     def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
         try:
             if key_string == "$$ANY$$":
+                # Note: This builtin JSON grammar includes *all* valid JSON (including, for example, arrays at the root)
                 ctx = self.grammar_compiler.compile_builtin_json_grammar()
             else:
                 ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
sglang/srt/disaggregation/decode.py CHANGED
@@ -21,6 +21,7 @@ Life cycle of a request in the decode server
 from __future__ import annotations
 
 import logging
+from collections import deque
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, Optional, Tuple
 
@@ -475,6 +476,48 @@ class SchedulerDisaggregationDecodeMixin:
 
         self.last_batch = batch
 
+    @torch.no_grad()
+    def event_loop_overlap_disagg_decode(self):
+        result_queue = deque()
+        self.last_batch: Optional[ScheduleBatch] = None
+        self.last_batch_is_extend = False  # last batch is modifed in-place, so we need another variable to track if it's extend
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            # polling and allocating kv cache
+            self.process_decode_queue()
+            batch = self.get_next_disagg_decode_batch_to_run()
+            self.cur_batch = batch
+            last_batch_is_extend = False
+
+            if batch:
+                # Generate fake extend output.
+                if batch.forward_mode.is_extend():
+                    # Note: Logprobs should be handled on the prefill engine.
+                    self.stream_output(batch.reqs, False)
+                    last_batch_is_extend = True
+                else:
+                    result = self.run_batch(batch)
+                    result_queue.append((batch.copy(), result))
+
+            # Process the results of the previous batch but skip if the last batch is extend
+            if self.last_batch and not self.last_batch_is_extend:
+                tmp_batch, tmp_result = result_queue.popleft()
+                self.process_batch_result(tmp_batch, tmp_result)
+
+            if batch is None and (
+                len(self.disagg_decode_transfer_queue.queue)
+                + len(self.disagg_decode_prealloc_queue.queue)
+                == 0
+            ):
+                # When the server is idle, do self-check and re-init some states
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+            self.last_batch_is_extend = last_batch_is_extend
+
     def get_next_disagg_decode_batch_to_run(
         self: Scheduler,
     ) -> Optional[Tuple[ScheduleBatch, bool]]:
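The new loop pipelines one step deep: batch N's result is processed while batch N+1 is already running, and fake-extend batches are skipped because they never enqueue a result. A toy sketch of the deque-based overlap, with illustrative stand-ins for the scheduler's methods:

    from collections import deque

    def run_batch(batch):  # stand-in for Scheduler.run_batch
        return f"result({batch})"

    def process_batch_result(batch, result):  # stand-in for the real handler
        print("processed", batch, result)

    result_queue = deque()
    last_batch = None
    for batch in ["b0", "b1", "b2"]:
        result_queue.append((batch, run_batch(batch)))  # launch current batch
        if last_batch is not None:  # finish the previous one
            process_batch_result(*result_queue.popleft())
        last_batch = batch

    while result_queue:  # drain the tail
        process_batch_result(*result_queue.popleft())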
sglang/srt/disaggregation/mini_lb.py CHANGED
@@ -23,8 +23,9 @@ class MiniLoadBalancer:
         return random.choice(self.prefill_servers), random.choice(self.decode_servers)
 
     async def generate(
-        self, modified_request, prefill_server, decode_server
+        self, modified_request, prefill_server, decode_server, endpoint
     ) -> ORJSONResponse:
+        assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
 
         async with aiohttp.ClientSession(
             timeout=aiohttp.ClientTimeout(
@@ -32,8 +33,8 @@ class MiniLoadBalancer:
             )  # Add timeout for request reliability
         ) as session:
             tasks = [
-                session.post(f"{prefill_server}/generate", json=modified_request),
-                session.post(f"{decode_server}/generate", json=modified_request),
+                session.post(f"{prefill_server}/{endpoint}", json=modified_request),
+                session.post(f"{decode_server}/{endpoint}", json=modified_request),
             ]
             # Wait for both responses to complete. Prefill should end first.
             prefill_response, decode_response = await asyncio.gather(*tasks)
@@ -43,7 +44,11 @@
                 status_code=decode_response.status,
             )
 
-    async def generate_stream(self, modified_request, prefill_server, decode_server):
+    async def generate_stream(
+        self, modified_request, prefill_server, decode_server, endpoint="generate"
+    ):
+        assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
+
         async def stream_results():
             async with aiohttp.ClientSession(
                 timeout=aiohttp.ClientTimeout(
@@ -54,10 +59,10 @@
                 # Create the tasks for both prefill and decode requests
                 tasks = [
                     session.post(
-                        f"{prefill_server}/generate", json=modified_request
+                        f"{prefill_server}/{endpoint}", json=modified_request
                     ),
                     session.post(
-                        f"{decode_server}/generate", json=modified_request
+                        f"{decode_server}/{endpoint}", json=modified_request
                     ),
                 ]
                 # Wait for both responses to complete. Since this is streaming, they return immediately.
@@ -157,6 +162,43 @@ async def get_model_info():
 async def handle_generate_request(request_data: dict):
     prefill_server, decode_server = load_balancer.select_pair()
 
+    # Parse and transform prefill_server for bootstrap data
+    parsed_url = urllib.parse.urlparse(prefill_server)
+    hostname = parsed_url.hostname
+    modified_request = request_data.copy()
+
+    batch_size = _get_request_batch_size(modified_request)
+    if batch_size is not None:
+        modified_request.update(
+            {
+                "bootstrap_host": [hostname] * batch_size,
+                "bootstrap_room": [
+                    _generate_bootstrap_room() for _ in range(batch_size)
+                ],
+            }
+        )
+    else:
+        modified_request.update(
+            {
+                "bootstrap_host": hostname,
+                "bootstrap_room": _generate_bootstrap_room(),
+            }
+        )
+
+    if request_data.get("stream", False):
+        return await load_balancer.generate_stream(
+            modified_request, prefill_server, decode_server, "generate"
+        )
+    else:
+        return await load_balancer.generate(
+            modified_request, prefill_server, decode_server, "generate"
+        )
+
+
+@app.post("/v1/chat/completions")
+async def handle_completion_request(request_data: dict):
+    prefill_server, decode_server = load_balancer.select_pair()
+
     # Parse and transform prefill_server for bootstrap data
     parsed_url = urllib.parse.urlparse(prefill_server)
     hostname = parsed_url.hostname
@@ -170,14 +212,33 @@ async def handle_generate_request(request_data: dict):
 
     if request_data.get("stream", False):
         return await load_balancer.generate_stream(
-            modified_request, prefill_server, decode_server
+            modified_request,
+            prefill_server,
+            decode_server,
+            endpoint="v1/chat/completions",
         )
     else:
         return await load_balancer.generate(
-            modified_request, prefill_server, decode_server
+            modified_request,
+            prefill_server,
+            decode_server,
+            endpoint="v1/chat/completions",
         )
 
 
+def _generate_bootstrap_room():
+    return random.randint(0, 2**63 - 1)
+
+
+# We may utilize `GenerateReqInput`'s logic later
+def _get_request_batch_size(request):
+    if (text := request.get("text")) is not None:
+        return None if isinstance(text, str) else len(text)
+    if (input_ids := request.get("input_ids")) is not None:
+        return None if isinstance(input_ids[0], int) else len(input_ids)
+    return None
+
+
 @app.get("/v1/models")
 async def get_models():
     prefill_server = load_balancer.prefill_servers[0]  # Get the first prefill server
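_get_request_batch_size infers batching from the shape of text / input_ids, so the bootstrap fields can be fanned out one per batch element (lists of hosts and rooms) or sent as scalars for a single request. Its behavior on the three shapes, copied out as a self-contained check:

    def _get_request_batch_size(request):
        if (text := request.get("text")) is not None:
            return None if isinstance(text, str) else len(text)
        if (input_ids := request.get("input_ids")) is not None:
            return None if isinstance(input_ids[0], int) else len(input_ids)
        return None

    assert _get_request_batch_size({"text": "hi"}) is None        # single request
    assert _get_request_batch_size({"text": ["a", "b"]}) == 2     # batch of two
    assert _get_request_batch_size({"input_ids": [[1, 2], [3]]}) == 2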
sglang/srt/disaggregation/mooncake/conn.py CHANGED
@@ -231,7 +231,7 @@ class MooncakeKVManager(BaseKVManager):
                 chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]
                 assert len(chunked_dst_kv_indice) == len(
                     kv_chunk.prefill_kv_indices
-                )
+                ), f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
 
                 ret = self.send_kvcache(
                     req.mooncake_session_id,
sglang/srt/disaggregation/nixl/__init__.py ADDED
@@ -0,0 +1 @@
+from .conn import NixlKVBootstrapServer, NixlKVManager, NixlKVReceiver, NixlKVSender