sglang 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. sglang/bench_latency.py +31 -13
  2. sglang/bench_server_latency.py +21 -10
  3. sglang/bench_serving.py +101 -7
  4. sglang/global_config.py +0 -1
  5. sglang/srt/conversation.py +11 -2
  6. sglang/srt/layers/attention/__init__.py +27 -5
  7. sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
  8. sglang/srt/layers/attention/flashinfer_backend.py +352 -83
  9. sglang/srt/layers/attention/triton_backend.py +6 -4
  10. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
  11. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
  12. sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
  13. sglang/srt/layers/sampler.py +6 -2
  14. sglang/srt/managers/data_parallel_controller.py +177 -0
  15. sglang/srt/managers/detokenizer_manager.py +31 -10
  16. sglang/srt/managers/io_struct.py +11 -2
  17. sglang/srt/managers/schedule_batch.py +126 -43
  18. sglang/srt/managers/schedule_policy.py +2 -1
  19. sglang/srt/managers/scheduler.py +245 -142
  20. sglang/srt/managers/tokenizer_manager.py +14 -1
  21. sglang/srt/managers/tp_worker.py +111 -1
  22. sglang/srt/mem_cache/chunk_cache.py +8 -4
  23. sglang/srt/mem_cache/memory_pool.py +77 -4
  24. sglang/srt/mem_cache/radix_cache.py +15 -7
  25. sglang/srt/model_executor/cuda_graph_runner.py +4 -4
  26. sglang/srt/model_executor/forward_batch_info.py +16 -21
  27. sglang/srt/model_executor/model_runner.py +100 -36
  28. sglang/srt/models/baichuan.py +2 -3
  29. sglang/srt/models/chatglm.py +5 -6
  30. sglang/srt/models/commandr.py +1 -2
  31. sglang/srt/models/dbrx.py +1 -2
  32. sglang/srt/models/deepseek.py +4 -5
  33. sglang/srt/models/deepseek_v2.py +5 -6
  34. sglang/srt/models/exaone.py +1 -2
  35. sglang/srt/models/gemma.py +2 -2
  36. sglang/srt/models/gemma2.py +5 -5
  37. sglang/srt/models/gpt_bigcode.py +5 -5
  38. sglang/srt/models/grok.py +1 -2
  39. sglang/srt/models/internlm2.py +1 -2
  40. sglang/srt/models/llama.py +1 -2
  41. sglang/srt/models/llama_classification.py +1 -2
  42. sglang/srt/models/llama_reward.py +2 -3
  43. sglang/srt/models/llava.py +4 -8
  44. sglang/srt/models/llavavid.py +1 -2
  45. sglang/srt/models/minicpm.py +1 -2
  46. sglang/srt/models/minicpm3.py +5 -6
  47. sglang/srt/models/mixtral.py +1 -2
  48. sglang/srt/models/mixtral_quant.py +1 -2
  49. sglang/srt/models/olmo.py +352 -0
  50. sglang/srt/models/olmoe.py +1 -2
  51. sglang/srt/models/qwen.py +1 -2
  52. sglang/srt/models/qwen2.py +1 -2
  53. sglang/srt/models/qwen2_moe.py +4 -5
  54. sglang/srt/models/stablelm.py +1 -2
  55. sglang/srt/models/torch_native_llama.py +1 -2
  56. sglang/srt/models/xverse.py +1 -2
  57. sglang/srt/models/xverse_moe.py +4 -5
  58. sglang/srt/models/yivl.py +1 -2
  59. sglang/srt/openai_api/adapter.py +97 -52
  60. sglang/srt/openai_api/protocol.py +10 -2
  61. sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
  62. sglang/srt/sampling/sampling_batch_info.py +105 -59
  63. sglang/srt/sampling/sampling_params.py +2 -0
  64. sglang/srt/server.py +171 -37
  65. sglang/srt/server_args.py +127 -48
  66. sglang/srt/utils.py +37 -14
  67. sglang/test/few_shot_gsm8k.py +4 -1
  68. sglang/test/few_shot_gsm8k_engine.py +144 -0
  69. sglang/test/srt/sampling/penaltylib/utils.py +16 -12
  70. sglang/version.py +1 -1
  71. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/METADATA +82 -32
  72. sglang-0.3.4.dist-info/RECORD +143 -0
  73. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
  74. sglang/srt/layers/attention/flashinfer_utils.py +0 -237
  75. sglang-0.3.3.dist-info/RECORD +0 -139
  76. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
  77. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/triton_ops/extend_attention.py
@@ -26,7 +26,9 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,
 )
 
-CUDA_CAPABILITY = torch.cuda.get_device_capability()
+is_cuda_available = torch.cuda.is_available()
+if is_cuda_available:
+    CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
 
 @triton.jit
@@ -286,12 +288,12 @@ def extend_attention_fwd(
     BLOCK_DPE = 0
     BLOCK_DV = triton.next_power_of_2(Lv)
 
-    if CUDA_CAPABILITY[0] >= 9:
+    if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
         if Lq <= 256:
             BLOCK_M, BLOCK_N = (128, 64)
         else:
             BLOCK_M, BLOCK_N = (32, 64)
-    elif CUDA_CAPABILITY[0] >= 8:
+    elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
         if Lq <= 128:
             BLOCK_M, BLOCK_N = (128, 128)
         elif Lq <= 256:
sglang/srt/layers/attention/triton_ops/prefill_attention.py
@@ -24,7 +24,9 @@ import torch
 import triton
 import triton.language as tl
 
-CUDA_CAPABILITY = torch.cuda.get_device_capability()
+is_cuda_available = torch.cuda.is_available()
+if is_cuda_available:
+    CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
 
 @triton.jit
@@ -145,7 +147,7 @@ def _fwd_kernel(
 
 
 def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
-    if CUDA_CAPABILITY[0] >= 8:
+    if is_cuda_available and CUDA_CAPABILITY[0] >= 8:
         BLOCK = 128
     else:
         BLOCK = 64
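
Both Triton attention modules (extend_attention.py and prefill_attention.py) now query the device capability only when CUDA is actually present, so importing them on a CPU-only machine no longer fails at import time; the capability checks then short-circuit on is_cuda_available. A minimal standalone sketch of the pattern (pick_block is an illustrative name, not an sglang function):

    import torch

    # Query the compute capability only when a CUDA device exists, so the
    # module can still be imported on CPU-only hosts.
    is_cuda_available = torch.cuda.is_available()
    if is_cuda_available:
        CUDA_CAPABILITY = torch.cuda.get_device_capability()

    def pick_block(Lq: int) -> int:
        # Capability-gated branch mirroring context_attention_fwd above;
        # without a GPU we simply fall through to the conservative default.
        if is_cuda_available and CUDA_CAPABILITY[0] >= 8:
            return 128
        return 64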
sglang/srt/layers/sampler.py
@@ -21,6 +21,10 @@ logger = logging.getLogger(__name__)
 
 
 class Sampler(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.use_nan_detectioin = not global_server_args_dict["disable_nan_detection"]
+
     def forward(
         self,
         logits: Union[torch.Tensor, LogitsProcessorOutput],
@@ -36,13 +40,13 @@ class Sampler(nn.Module):
         logits = None
         del logits
 
-        if torch.any(torch.isnan(probs)):
+        if self.use_nan_detectioin and torch.any(torch.isnan(probs)):
             logger.warning("Detected errors during sampling! NaN in the probability.")
             probs = torch.where(
                 torch.isnan(probs), torch.full_like(probs, 1e-10), probs
             )
 
-        if sampling_info.top_ks.max().item() <= 1:
+        if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
             batch_next_token_ids = torch.argmax(probs, -1)
         elif global_server_args_dict["sampling_backend"] == "flashinfer":
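
Two behavioural changes here: NaN detection on the probability tensor can now be switched off through a disable_nan_detection server argument (read once in __init__), and the all-greedy check uses a precomputed sampling_info.is_all_greedy flag instead of reducing top_ks on every step. A rough standalone sketch of the greedy fast path, not the module's actual dispatch code:

    import torch

    def sample_next_tokens(probs: torch.Tensor, is_all_greedy: bool) -> torch.Tensor:
        """Sketch only: probs has shape [batch, vocab] and rows sum to 1."""
        if is_all_greedy:
            # Every request uses top_k == 1, so argmax is equivalent to
            # sampling and avoids the heavier sampling kernels.
            return torch.argmax(probs, dim=-1)
        # Simplified fallback; the real code dispatches to the flashinfer or
        # torch sampling backends with top-k/top-p filtering applied.
        return torch.multinomial(probs, num_samples=1).squeeze(-1)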
sglang/srt/managers/data_parallel_controller.py
@@ -0,0 +1,177 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""A controller that dispatches requests to multiple data parallel workers."""
+
+import logging
+import multiprocessing as mp
+from enum import Enum, auto
+
+import zmq
+
+from sglang.srt.managers.io_struct import (
+    TokenizedEmbeddingReqInput,
+    TokenizedGenerateReqInput,
+    TokenizedRewardReqInput,
+)
+from sglang.srt.managers.scheduler import run_scheduler_process
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import (
+    configure_logger,
+    kill_parent_process,
+    suppress_other_loggers,
+)
+from sglang.utils import get_exception_traceback
+
+logger = logging.getLogger(__name__)
+
+
+class LoadBalanceMethod(Enum):
+    """Load balance method."""
+
+    ROUND_ROBIN = auto()
+    SHORTEST_QUEUE = auto()
+
+    @classmethod
+    def from_str(cls, method: str):
+        method = method.upper()
+        try:
+            return cls[method]
+        except KeyError as exc:
+            raise ValueError(f"Invalid load balance method: {method}") from exc
+
+
+class DataParallelController:
+    """A controller that dispatches requests to multiple data parallel workers."""
+
+    def __init__(self, server_args, port_args) -> None:
+        # Parse args
+        self.server_args = server_args
+        self.port_args = port_args
+        self.load_balance_method = LoadBalanceMethod.from_str(
+            server_args.load_balance_method
+        )
+
+        # Init inter-process communication
+        self.context = zmq.Context(1 + server_args.dp_size)
+        self.recv_from_tokenizer = self.context.socket(zmq.PULL)
+        self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}")
+
+        # Dispatch method
+        self.round_robin_counter = 0
+        dispatch_lookup = {
+            LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
+            LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
+        }
+        self.dispatching = dispatch_lookup[self.load_balance_method]
+
+        # Start data parallel workers
+        base_gpu_id = 0
+        self.workers = []
+        for dp_rank in range(server_args.dp_size):
+            tmp_port_args = PortArgs.init_new(server_args)
+            tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
+
+            send_to = self.launch_tensor_parallel_group(
+                server_args,
+                tmp_port_args,
+                base_gpu_id,
+                dp_rank,
+            )
+
+            self.workers.append(send_to)
+            base_gpu_id += server_args.tp_size
+
+    def launch_tensor_parallel_group(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        base_gpu_id: int,
+        dp_rank: int,
+    ):
+        # Launch tensor parallel scheduler processes
+        scheduler_procs = []
+        scheduler_pipe_readers = []
+        tp_size_per_node = server_args.tp_size // server_args.nnodes
+        tp_rank_range = range(
+            tp_size_per_node * server_args.node_rank,
+            tp_size_per_node * (server_args.node_rank + 1),
+        )
+        for tp_rank in tp_rank_range:
+            reader, writer = mp.Pipe(duplex=False)
+            gpu_id = base_gpu_id + tp_rank % tp_size_per_node
+            proc = mp.Process(
+                target=run_scheduler_process,
+                args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
+            )
+            proc.start()
+            scheduler_procs.append(proc)
+            scheduler_pipe_readers.append(reader)
+
+        send_to = self.context.socket(zmq.PUSH)
+        send_to.connect(f"ipc://{port_args.scheduler_input_ipc_name}")
+
+        # Wait for model to finish loading
+        for i in range(len(scheduler_pipe_readers)):
+            scheduler_pipe_readers[i].recv()
+
+        return send_to
+
+    def round_robin_scheduler(self, req):
+        self.workers[self.round_robin_counter].send_pyobj(req)
+        self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
+
+    def shortest_queue_scheduler(self, input_requests):
+        raise NotImplementedError()
+
+    def event_loop(self):
+        while True:
+            while True:
+                try:
+                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+                except zmq.ZMQError:
+                    break
+
+                if isinstance(
+                    recv_req,
+                    (
+                        TokenizedGenerateReqInput,
+                        TokenizedEmbeddingReqInput,
+                        TokenizedRewardReqInput,
+                    ),
+                ):
+                    self.dispatching(recv_req)
+                else:
+                    # Send other control messages to all workers
+                    for worker in self.workers:
+                        worker.queue.put(recv_req)
+
+
+def run_data_parallel_controller_process(
+    server_args: ServerArgs,
+    port_args: PortArgs,
+    pipe_writer,
+):
+    configure_logger(server_args)
+    suppress_other_loggers()
+
+    try:
+        controller = DataParallelController(server_args, port_args)
+        pipe_writer.send("ready")
+        controller.event_loop()
+    except Exception:
+        msg = get_exception_traceback()
+        logger.error(msg)
+        kill_parent_process()
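
The new controller binds one ZMQ PULL socket for tokenized requests and forwards each request to exactly one of the dp_size tensor-parallel groups; ROUND_ROBIN is the only dispatch policy implemented (SHORTEST_QUEUE raises NotImplementedError). A minimal standalone sketch of the round-robin policy, using a hypothetical worker object with a send() method in place of the real zmq PUSH sockets:

    from typing import Any, List

    class RoundRobinDispatcher:
        """Stand-in for the ROUND_ROBIN policy; not the sglang class itself."""

        def __init__(self, workers: List[Any]) -> None:
            self.workers = workers
            self.counter = 0

        def dispatch(self, req: Any) -> None:
            # Each tokenized request goes to exactly one data parallel worker,
            # cycling through the workers in order.
            self.workers[self.counter].send(req)
            self.counter = (self.counter + 1) % len(self.workers)

    if __name__ == "__main__":
        class PrintWorker:
            def __init__(self, name): self.name = name
            def send(self, req): print(f"{self.name} <- {req}")

        d = RoundRobinDispatcher([PrintWorker("dp0"), PrintWorker("dp1")])
        for r in ["req-a", "req-b", "req-c"]:
            d.dispatch(r)  # alternates dp0, dp1, dp0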
sglang/srt/managers/detokenizer_manager.py
@@ -18,7 +18,7 @@ limitations under the License.
 import dataclasses
 import logging
 from collections import OrderedDict
-from typing import List
+from typing import List, Union
 
 import zmq
 
@@ -29,7 +29,7 @@ from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
     UpdateWeightReqOutput,
 )
-from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
+from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import configure_logger, kill_parent_process
 from sglang.utils import find_printable_text, get_exception_traceback
@@ -75,6 +75,21 @@ class DetokenizerManager:
 
         self.decode_status = LimitedCapacityDict()
 
+    def trim_eos(self, output: Union[str, List[int]], finished_reason, no_stop_trim):
+        if no_stop_trim:
+            return output
+
+        # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
+        if isinstance(finished_reason, FINISH_MATCHED_STR) and isinstance(output, str):
+            pos = output.find(finished_reason.matched)
+            return output[:pos] if pos != -1 else output
+        if isinstance(finished_reason, FINISH_MATCHED_TOKEN) and isinstance(
+            output, list
+        ):
+            assert len(output) > 0
+            return output[:-1]
+        return output
+
     def event_loop(self):
         """The event loop that handles requests"""
 
@@ -122,7 +137,13 @@ class DetokenizerManager:
                     s = self.decode_status[rid]
                     s.decode_ids = recv_obj.decode_ids[i]
 
-                read_ids.append(s.decode_ids[s.surr_offset :])
+                read_ids.append(
+                    self.trim_eos(
+                        s.decode_ids[s.surr_offset :],
+                        recv_obj.finished_reason[i],
+                        recv_obj.no_stop_trim[i],
+                    )
+                )
                 surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])
 
             # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
@@ -152,13 +173,13 @@ class DetokenizerManager:
                 else:
                     new_text = find_printable_text(new_text)
 
-                output_strs.append(s.decoded_text + new_text)
-
-                # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
-                if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
-                    pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
-                    if pos != -1:
-                        output_strs[i] = output_strs[i][:pos]
+                output_strs.append(
+                    self.trim_eos(
+                        s.decoded_text + new_text,
+                        recv_obj.finished_reason[i],
+                        recv_obj.no_stop_trim[i],
+                    )
+                )
 
             self.send_to_tokenizer.send_pyobj(
                 BatchStrOut(
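
trim_eos consolidates the stop-sequence trimming that previously lived inline in the string path, applies it to the incremental token ids as well (FINISH_MATCHED_TOKEN drops the final matched token), and lets a request opt out per call via no_stop_trim. A small standalone illustration of the FINISH_MATCHED_STR branch (trim_matched_stop_str is a hypothetical re-implementation, not the sglang method):

    def trim_matched_stop_str(output: str, matched: str, no_stop_trim: bool = False) -> str:
        """Sketch of the string branch of trim_eos()."""
        if no_stop_trim:
            return output
        pos = output.find(matched)
        return output[:pos] if pos != -1 else output

    # A request that finished on the stop string "</s>":
    assert trim_matched_stop_str("Hello world</s> extra", "</s>") == "Hello world"
    # With no_stop_trim=True the raw text is returned unchanged.
    assert trim_matched_stop_str("Hello world</s>", "</s>", no_stop_trim=True) == "Hello world</s>"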
sglang/srt/managers/io_struct.py
@@ -20,6 +20,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
 
 import uuid
 from dataclasses import dataclass
+from enum import Enum
 from typing import Dict, List, Optional, Union
 
 from sglang.srt.managers.schedule_batch import BaseFinishReason
@@ -55,6 +56,9 @@ class GenerateReqInput:
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
+    # Whether it is a single request or a batch request
+    is_single: bool = True
+
     def post_init(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
@@ -119,8 +123,7 @@ class GenerateReqInput:
             elif not isinstance(self.image_data, list):
                 self.image_data = [self.image_data] * num
             elif isinstance(self.image_data, list):
-                # FIXME incorrect order for duplication
-                self.image_data = self.image_data * num
+                pass
 
             if self.sampling_params is None:
                 self.sampling_params = [{}] * num
@@ -295,6 +298,7 @@ class BatchTokenIDOut:
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]
+    no_stop_trim: List[bool]
 
 
 @dataclass
@@ -344,3 +348,8 @@ class UpdateWeightReqOutput:
 class AbortReq:
     # The request id
     rid: str
+
+
+class ProfileReq(Enum):
+    START_PROFILE = 1
+    STOP_PROFILE = 2
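
Besides the new ProfileReq control message and the no_stop_trim list carried per request in BatchTokenIDOut, the batch-normalization fix means an image_data list supplied alongside a list of prompts is now kept as-is (one entry per prompt) instead of being duplicated. A hypothetical, trimmed-down version of GenerateReqInput.post_init() to show the normalization; the real class has many more fields:

    from dataclasses import dataclass
    from typing import List, Optional, Union

    @dataclass
    class TinyGenerateReq:
        text: Union[str, List[str]]
        image_data: Optional[Union[str, List[str]]] = None
        is_single: bool = True

        def post_init(self):
            if isinstance(self.text, str):
                self.is_single = True
                return
            self.is_single = False
            num = len(self.text)
            if self.image_data is None:
                self.image_data = [None] * num
            elif not isinstance(self.image_data, list):
                self.image_data = [self.image_data] * num
            # A list of image_data is now left untouched (one entry per
            # prompt), matching the `pass` branch in the diff above.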