sglang-0.2.5-py3-none-any.whl → sglang-0.2.7-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (81)
  1. sglang/__init__.py +33 -26
  2. sglang/api.py +9 -1
  3. sglang/bench_latency.py +2 -2
  4. sglang/bench_serving.py +10 -1
  5. sglang/check_env.py +1 -1
  6. sglang/lang/backend/litellm.py +1 -1
  7. sglang/lang/backend/openai.py +1 -1
  8. sglang/lang/backend/runtime_endpoint.py +4 -4
  9. sglang/lang/interpreter.py +24 -9
  10. sglang/lang/ir.py +1 -1
  11. sglang/srt/constrained/__init__.py +15 -0
  12. sglang/srt/constrained/base_cache.py +15 -0
  13. sglang/srt/constrained/fsm_cache.py +36 -1
  14. sglang/srt/constrained/jump_forward.py +15 -0
  15. sglang/srt/conversation.py +26 -0
  16. sglang/srt/hf_transformers_utils.py +18 -1
  17. sglang/srt/layers/context_flashattention_nopad.py +15 -0
  18. sglang/srt/layers/extend_attention.py +15 -0
  19. sglang/srt/layers/fused_moe.py +15 -0
  20. sglang/srt/layers/linear.py +15 -0
  21. sglang/srt/layers/logits_processor.py +109 -72
  22. sglang/srt/layers/quantization/__init__.py +15 -0
  23. sglang/srt/layers/quantization/fp8.py +15 -0
  24. sglang/srt/layers/radix_attention.py +21 -3
  25. sglang/srt/layers/token_attention.py +16 -1
  26. sglang/srt/managers/{controller/manager_multi.py → controller_multi.py} +17 -2
  27. sglang/srt/managers/{controller/manager_single.py → controller_single.py} +17 -2
  28. sglang/srt/managers/detokenizer_manager.py +16 -1
  29. sglang/srt/managers/io_struct.py +38 -5
  30. sglang/srt/managers/{controller/schedule_heuristic.py → policy_scheduler.py} +37 -22
  31. sglang/srt/managers/{controller/infer_batch.py → schedule_batch.py} +85 -25
  32. sglang/srt/managers/tokenizer_manager.py +99 -57
  33. sglang/srt/managers/{controller/tp_worker.py → tp_worker.py} +177 -81
  34. sglang/srt/mem_cache/flush_cache.py +33 -0
  35. sglang/srt/{memory_pool.py → mem_cache/memory_pool.py} +16 -1
  36. sglang/srt/{managers/controller → mem_cache}/radix_cache.py +15 -0
  37. sglang/srt/mm_utils.py +15 -0
  38. sglang/srt/model_config.py +20 -0
  39. sglang/srt/{managers/controller → model_executor}/cuda_graph_runner.py +42 -18
  40. sglang/srt/{managers/controller → model_executor}/model_runner.py +51 -16
  41. sglang/srt/model_loader/model_loader.py +15 -0
  42. sglang/srt/model_loader/utils.py +16 -1
  43. sglang/srt/models/chatglm.py +16 -1
  44. sglang/srt/models/commandr.py +16 -1
  45. sglang/srt/models/dbrx.py +16 -1
  46. sglang/srt/models/deepseek.py +16 -1
  47. sglang/srt/models/deepseek_v2.py +532 -0
  48. sglang/srt/models/gemma.py +16 -1
  49. sglang/srt/models/gemma2.py +16 -1
  50. sglang/srt/models/gpt_bigcode.py +16 -1
  51. sglang/srt/models/grok.py +16 -1
  52. sglang/srt/models/internlm2.py +16 -1
  53. sglang/srt/models/llama2.py +16 -1
  54. sglang/srt/models/llama_classification.py +19 -4
  55. sglang/srt/models/llava.py +17 -2
  56. sglang/srt/models/llavavid.py +17 -2
  57. sglang/srt/models/minicpm.py +16 -1
  58. sglang/srt/models/mistral.py +15 -0
  59. sglang/srt/models/mixtral.py +16 -1
  60. sglang/srt/models/mixtral_quant.py +16 -1
  61. sglang/srt/models/qwen.py +16 -1
  62. sglang/srt/models/qwen2.py +16 -1
  63. sglang/srt/models/qwen2_moe.py +16 -1
  64. sglang/srt/models/stablelm.py +16 -1
  65. sglang/srt/models/yivl.py +15 -0
  66. sglang/srt/openai_api/adapter.py +545 -160
  67. sglang/srt/openai_api/protocol.py +65 -1
  68. sglang/srt/sampling_params.py +20 -4
  69. sglang/srt/server.py +90 -37
  70. sglang/srt/server_args.py +76 -17
  71. sglang/srt/utils.py +15 -0
  72. sglang/test/test_programs.py +5 -1
  73. sglang/utils.py +22 -0
  74. sglang/version.py +1 -1
  75. {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/METADATA +40 -12
  76. sglang-0.2.7.dist-info/RECORD +93 -0
  77. {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/WHEEL +1 -1
  78. sglang/srt/flush_cache.py +0 -18
  79. sglang-0.2.5.dist-info/RECORD +0 -92
  80. {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/LICENSE +0 -0
  81. {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py

@@ -1,7 +1,22 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """Logits processing."""
 
 import dataclasses
-from typing import List, Union
+from typing import List, Optional, Union
 
 import torch
 from torch import nn
@@ -10,7 +25,7 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
 )
 
-from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
+from sglang.srt.model_executor.model_runner import ForwardMode, InputMetadata
 
 
 @dataclasses.dataclass
@@ -22,23 +37,23 @@ class LogitProcessorOutput:
 
     # The normlaized logprobs of prompts. shape: [#seq]
     normalized_prompt_logprobs: torch.Tensor
-    # The logprobs of prefill tokens. shape: [#token, vocab_size]
-    prefill_token_logprobs: torch.Tensor
+    # The logprobs of input tokens. shape: [#token, vocab_size]
+    input_token_logprobs: torch.Tensor
 
-    # The logprob and id of the top-k tokens in prefill positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-    prefill_top_logprobs: List
-    # The logprob and id of the top-k tokens in decode positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-    decode_top_logprobs: List
+    # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
+    input_top_logprobs: List
+    # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
+    output_top_logprobs: List
 
 
 @dataclasses.dataclass
 class LogitsMetadata:
     forward_mode: ForwardMode
-    return_logprob: bool
+    return_logprob: bool = False
 
-    extend_seq_lens: torch.Tensor = None
-    extend_start_loc: torch.Tensor = None
-    top_logprobs_nums: List[int] = None
+    extend_seq_lens: Optional[torch.Tensor] = None
+    extend_start_loc: Optional[torch.Tensor] = None
+    top_logprobs_nums: Optional[List[int]] = None
 
     @classmethod
     def from_input_metadata(cls, input_metadata: InputMetadata):
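One real fix in this hunk: fields that default to None are now annotated `Optional[...]` instead of bare `torch.Tensor`/`List[int]`, which strict type checkers reject. A minimal standalone sketch of the corrected pattern (toy class, not the sglang one):

import dataclasses
from typing import List, Optional


@dataclasses.dataclass
class ToyMetadata:
    # A bare `nums: List[int] = None` is flagged by strict checkers,
    # because None is not a List[int]; Optional makes the default legal.
    nums: Optional[List[int]] = None


m = ToyMetadata()
if m.nums is not None:  # callers must guard before using the field
    print(max(m.nums))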
@@ -58,20 +73,16 @@ class LogitsProcessor(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
 
     def _get_normalized_prompt_logprobs(
-        self, prefill_token_logprobs, logits_metadata: LogitsMetadata
+        self, input_token_logprobs, logits_metadata: LogitsMetadata
     ):
-        logprobs_cumsum = torch.cumsum(
-            prefill_token_logprobs, dim=0, dtype=torch.float32
-        )
+        logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
 
         start = logits_metadata.extend_start_loc.clone()
         end = start + logits_metadata.extend_seq_lens - 2
-        start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
-        end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
+        start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
+        end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
         sum_logp = (
-            logprobs_cumsum[end]
-            - logprobs_cumsum[start]
-            + prefill_token_logprobs[start]
+            logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
         )
         normalized_prompt_logprobs = sum_logp / (
             (logits_metadata.extend_seq_lens - 1).clamp(min=1)
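The renamed helper keeps its range-sum trick: one cumulative sum over the flat, packed logprob tensor totals each prompt's token logprobs without a Python loop. A CPU-only toy sketch with made-up values (variable names mirror the diff; this is not the sglang API):

import torch

# Two packed prompts of lengths 3 and 4, token logprobs laid out flat.
input_token_logprobs = torch.tensor([-0.5, -1.0, -2.0, -0.1, -0.2, -0.3, -0.4])
extend_start_loc = torch.tensor([0, 3])
extend_seq_lens = torch.tensor([3, 4])

logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
start = extend_start_loc.clone()
end = start + extend_seq_lens - 2
start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)

# Inclusive range sum over [start, end]: cumsum[end] - cumsum[start] + x[start].
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]

# Each prompt contributes len - 1 logprobs (the first token has no logprob).
normalized = sum_logp / (extend_seq_lens - 1).clamp(min=1)
print(normalized)  # tensor([-0.7500, -0.2000])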
@@ -79,37 +90,51 @@ class LogitsProcessor(nn.Module):
 
         return normalized_prompt_logprobs
 
-    def _get_top_logprobs(self, all_logprobs, logits_metadata: LogitsMetadata):
-        # TODO: vectorize the code below
+    @staticmethod
+    def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
         if logits_metadata.forward_mode == ForwardMode.DECODE:
-            decode_top_logprobs = []
-            for i in range(all_logprobs.shape[0]):
-                k = logits_metadata.top_logprobs_nums[i]
-                t = all_logprobs[i].topk(k)
-                v_cpu = t.values.tolist()
-                p_cpu = t.indices.tolist()
-                decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
-            return None, decode_top_logprobs
+            output_top_logprobs = []
+            max_k = max(logits_metadata.top_logprobs_nums)
+            ret = all_logprobs.topk(max_k, dim=1)
+            values = ret.values.tolist()
+            indices = ret.indices.tolist()
+            for i, k in enumerate(logits_metadata.top_logprobs_nums):
+                output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k])))
+            return None, output_top_logprobs
         else:
-            prefill_top_logprobs, decode_top_logprobs = [], []
+            # TODO: vectorize the code below
+            input_top_logprobs, output_top_logprobs = [], []
             pt = 0
             extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
+
+            max_k = max(logits_metadata.top_logprobs_nums)
+            ret = all_logprobs.topk(max_k, dim=1)
+            values = ret.values.tolist()
+            indices = ret.indices.tolist()
+
             for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
                 if extend_seq_len == 0:
-                    prefill_top_logprobs.append([])
-                    decode_top_logprobs.append([])
+                    input_top_logprobs.append([])
+                    output_top_logprobs.append([])
                     continue
                 k = logits_metadata.top_logprobs_nums[i]
-                t = all_logprobs[pt : pt + extend_seq_len].topk(k)
-                vs_cpu = t.values.tolist()
-                ps_cpu = t.indices.tolist()
-                prefill_top_logprobs.append(
-                    [list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
+                input_top_logprobs.append(
+                    [
+                        list(zip(values[pt + j][:k], indices[pt + j][:k]))
+                        for j in range(extend_seq_len - 1)
+                    ]
+                )
+                output_top_logprobs.append(
+                    list(
+                        zip(
+                            values[pt + extend_seq_len - 1][:k],
+                            indices[pt + extend_seq_len - 1][:k],
+                        )
+                    )
                 )
-                decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
                 pt += extend_seq_len
 
-        return prefill_top_logprobs, decode_top_logprobs
+            return input_top_logprobs, output_top_logprobs
 
     def forward(
         self,
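The decode branch above drops the per-request `topk` loop in favor of one batched `topk` at the largest requested k, then slices each row down to its own k on the CPU side. A self-contained sketch of that pattern under assumed toy shapes (not sglang's API):

import torch

all_logprobs = torch.log_softmax(torch.randn(3, 10), dim=-1)  # 3 requests, vocab 10
top_logprobs_nums = [1, 3, 2]  # per-request k values

# One kernel launch at max k instead of one topk call per request.
max_k = max(top_logprobs_nums)
ret = all_logprobs.topk(max_k, dim=1)
values = ret.values.tolist()
indices = ret.indices.tolist()

# Cheap Python-side slicing recovers each request's own k.
output_top_logprobs = [
    list(zip(values[i][:k], indices[i][:k]))
    for i, k in enumerate(top_logprobs_nums)
]
print(len(output_top_logprobs[1]))  # 3 (logprob, token_id) pairs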
@@ -136,7 +161,7 @@
         last_logits = torch.matmul(last_hidden, weight.T)
         if self.tp_size > 1:
             last_logits = tensor_model_parallel_all_gather(last_logits)
-        last_logits = last_logits[:, : self.config.vocab_size]
+        last_logits = last_logits[:, : self.config.vocab_size].float()
 
         if hasattr(self.config, "final_logit_softcapping"):
             last_logits /= self.config.final_logit_softcapping
@@ -149,63 +174,75 @@
                 next_token_logits=last_logits,
                 next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
-                prefill_token_logprobs=None,
-                prefill_top_logprobs=None,
-                decode_top_logprobs=None,
+                input_token_logprobs=None,
+                input_top_logprobs=None,
+                output_top_logprobs=None,
             )
         else:
             # When logprob is requested, compute the logits for all tokens.
             if logits_metadata.forward_mode == ForwardMode.DECODE:
-                all_logits = last_logits
-            else:
-                all_logits = torch.matmul(hidden_states, weight.T)
-                if self.tp_size > 1:
-                    all_logits = tensor_model_parallel_all_gather(all_logits)
-                all_logits = all_logits[:, : self.config.vocab_size]
+                last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
 
-            all_logprobs = all_logits.float()
-            del all_logits
-            all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
-
-            # Get the logprob of top-k tokens
-            return_top_logprob = any(x > 0 for x in logits_metadata.top_logprobs_nums)
-            if return_top_logprob:
-                prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
-                    all_logprobs, logits_metadata
+                # Get the logprob of top-k tokens
+                return_top_logprob = any(
+                    x > 0 for x in logits_metadata.top_logprobs_nums
                 )
-            else:
-                prefill_top_logprobs = decode_top_logprobs = None
+                if return_top_logprob:
+                    output_top_logprobs = self.get_top_logprobs(
+                        last_logprobs, logits_metadata
+                    )[1]
+                else:
+                    output_top_logprobs = None
 
-            if logits_metadata.forward_mode == ForwardMode.DECODE:
                 return LogitProcessorOutput(
                     next_token_logits=last_logits,
-                    next_token_logprobs=all_logprobs,
+                    next_token_logprobs=last_logprobs,
                     normalized_prompt_logprobs=None,
-                    prefill_token_logprobs=None,
-                    prefill_top_logprobs=None,
-                    decode_top_logprobs=decode_top_logprobs,
+                    input_token_logprobs=None,
+                    input_top_logprobs=None,
+                    output_top_logprobs=output_top_logprobs,
                 )
             else:
+                all_logits = torch.matmul(hidden_states, weight.T)
+                if self.tp_size > 1:
+                    all_logits = tensor_model_parallel_all_gather(all_logits)
+                all_logits = all_logits[:, : self.config.vocab_size].float()
+
+                all_logprobs = all_logits
+                del all_logits
+                all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
+
+                # Get the logprob of top-k tokens
+                return_top_logprob = any(
+                    x > 0 for x in logits_metadata.top_logprobs_nums
+                )
+                if return_top_logprob:
+                    input_top_logprobs, output_top_logprobs = self.get_top_logprobs(
+                        all_logprobs, logits_metadata
+                    )
+                else:
+                    input_top_logprobs = output_top_logprobs = None
+
                 last_logprobs = all_logprobs[last_index]
 
                 # Compute the logprobs and normalized logprobs for the prefill tokens.
                 # Note that we pad a zero at the end of each sequence for easy computation.
-                prefill_token_logprobs = all_logprobs[
+                input_token_logprobs = all_logprobs[
                     torch.arange(all_logprobs.shape[0], device="cuda"),
                     torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
                 ]
 
                 normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
-                    prefill_token_logprobs, logits_metadata
+                    input_token_logprobs, logits_metadata
                 )
 
                 return LogitProcessorOutput(
                     next_token_logits=last_logits,
                     next_token_logprobs=last_logprobs,
                     normalized_prompt_logprobs=normalized_prompt_logprobs,
-                    prefill_token_logprobs=prefill_token_logprobs,
-                    prefill_top_logprobs=prefill_top_logprobs,
-                    decode_top_logprobs=decode_top_logprobs,
+                    input_token_logprobs=input_token_logprobs,
+                    input_top_logprobs=input_top_logprobs,
+                    output_top_logprobs=output_top_logprobs,
                 )
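The gather that builds `input_token_logprobs` pairs position i with the id of token i+1, padding a dummy id 0 after the last token; `_get_normalized_prompt_logprobs` then excludes that padded slot through its `end` index. A CPU-only toy version of the same indexing (the real tensors live on CUDA):

import torch

vocab_size = 5
all_logprobs = torch.log_softmax(torch.randn(4, vocab_size), dim=-1)  # 4 positions
input_ids = torch.tensor([2, 4, 1, 3])

# Row i is scored against token i+1; the final row gets the dummy id 0.
input_token_logprobs = all_logprobs[
    torch.arange(all_logprobs.shape[0]),
    torch.cat([input_ids[1:], torch.tensor([0])]),
]
print(input_token_logprobs.shape)  # torch.Size([4])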
 
sglang/srt/layers/quantization/__init__.py

@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+[Apache License 2.0 notice, identical to the header shown in logits_processor.py above]
+"""
+
 # temporarily adapted from vLLM
 # FIXME: in progress of refactoring the model loader
 
sglang/srt/layers/quantization/fp8.py

@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+[Apache License 2.0 notice, identical to the header shown in logits_processor.py above]
+"""
+
 # adapted from https://github.com/vllm-project/vllm/blob/e76466dde2bc9525d55165ceaa600d298c7bf773/vllm/model_executor/layers/quantization/fp8.py
 # FIXME refactor in progress
 from typing import Any, Dict, List, Optional, Union
sglang/srt/layers/radix_attention.py

@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+[Apache License 2.0 notice, identical to the header shown in logits_processor.py above]
+"""
+
 """Radix attention."""
 
 import torch
@@ -7,8 +22,11 @@ from torch import nn
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
-from sglang.srt.server import global_server_args_dict
+from sglang.srt.model_executor.model_runner import (
+    ForwardMode,
+    InputMetadata,
+    global_server_args_dict,
+)
 
 
 class RadixAttention(nn.Module):
@@ -85,7 +103,7 @@ class RadixAttention(nn.Module):
         return o
 
     def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        if not input_metadata.use_ragged:
+        if not input_metadata.flashinfer_use_ragged:
             self.store_kv_cache(k, v, input_metadata)
 
         o = input_metadata.flashinfer_prefill_wrapper_paged.forward(
sglang/srt/layers/token_attention.py

@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+[Apache License 2.0 notice, identical to the header shown in logits_processor.py above]
+"""
+
 # Adapted from
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
@@ -5,7 +20,7 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.server import global_server_args_dict
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 
 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
sglang/srt/managers/{controller/manager_multi.py → controller_multi.py}

@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+[Apache License 2.0 notice, identical to the header shown in logits_processor.py above]
+"""
+
 """
 A controller that manages multiple data parallel workers.
 Each data parallel worker can manage multiple tensor parallel workers.
@@ -12,7 +27,7 @@ from enum import Enum, auto
 import numpy as np
 import zmq
 
-from sglang.srt.managers.controller.manager_single import (
+from sglang.srt.managers.controller_single import (
     start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.io_struct import (
@@ -24,7 +39,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback
 
-logger = logging.getLogger("srt.controller")
+logger = logging.getLogger(__name__)
 
 
 class LoadBalanceMethod(Enum):
sglang/srt/managers/{controller/manager_single.py → controller_single.py}

@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+[Apache License 2.0 notice, identical to the header shown in logits_processor.py above]
+"""
+
 """A controller that manages a group of tensor parallel workers."""
 
 import logging
@@ -7,7 +22,7 @@ from typing import List
 
 import zmq
 
-from sglang.srt.managers.controller.tp_worker import (
+from sglang.srt.managers.tp_worker import (
     ModelTpServer,
     broadcast_recv_input,
     launch_tp_servers,
@@ -16,7 +31,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback
 
-logger = logging.getLogger("srt.controller")
+logger = logging.getLogger(__name__)
 
 
 class ControllerSingle:
sglang/srt/managers/detokenizer_manager.py

@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+[Apache License 2.0 notice, identical to the header shown in logits_processor.py above]
+"""
+
 """DetokenizerManager is a process that detokenizes the token ids."""
 
 import asyncio
@@ -10,8 +25,8 @@ import zmq
 import zmq.asyncio
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
 from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
+from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry
 
sglang/srt/managers/io_struct.py

@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+[Apache License 2.0 notice, identical to the header shown in logits_processor.py above]
+"""
+
 """
 The definition of objects transfered between different
 processes (TokenizerManager, DetokenizerManager, Controller).
@@ -7,7 +22,7 @@ import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
 
-from sglang.srt.managers.controller.infer_batch import BaseFinishReason
+from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling_params import SamplingParams
 
 
@@ -20,7 +35,7 @@ class GenerateReqInput:
     # The image input. It can be a file name, a url, or base64 encoded string.
     # See also python/sglang/srt/utils.py:load_image.
     image_data: Optional[Union[List[str], str]] = None
-    # The sampling_params.
+    # The sampling_params. See descriptions below.
    sampling_params: Union[List[Dict], Dict] = None
     # The request id.
     rid: Optional[Union[List[str], str]] = None
@@ -30,7 +45,7 @@ class GenerateReqInput:
     logprob_start_len: Optional[Union[List[int], int]] = None
     # The number of top logprobs to return.
     top_logprobs_num: Optional[Union[List[int], int]] = None
-    # Whether to detokenize tokens in logprobs.
+    # Whether to detokenize tokens in text in the returned logprobs.
     return_text_in_logprobs: bool = False
     # Whether to stream output.
     stream: bool = False
@@ -64,8 +79,26 @@ class GenerateReqInput:
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = 0
         else:
-
-            parallel_sample_num = self.sampling_params.get("n", 1)
+            parallel_sample_num_list = []
+            if isinstance(self.sampling_params, dict):
+                parallel_sample_num = self.sampling_params.get("n", 1)
+            elif isinstance(self.sampling_params, list):
+                for sp in self.sampling_params:
+                    parallel_sample_num = sp.get("n", 1)
+                    parallel_sample_num_list.append(parallel_sample_num)
+                parallel_sample_num = max(parallel_sample_num_list)
+                all_equal = all(
+                    element == parallel_sample_num
+                    for element in parallel_sample_num_list
+                )
+                if parallel_sample_num > 1 and (not all_equal):
+                    ## TODO cope with the case that the parallel_sample_num is different for different samples
+                    raise ValueError(
+                        "The parallel_sample_num should be the same for all samples in sample params."
+                    )
+            else:
+                parallel_sample_num = 1
+            self.parallel_sample_num = parallel_sample_num
 
         if parallel_sample_num != 1:
             # parallel sampling +1 represents the original prefill stage
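The new `__post_init__` branch reads the OpenAI-style `n` (number of parallel samples) out of either a single dict or a per-request list, and rejects batches that mix different values. A standalone helper mirroring that logic (a hypothetical function, not part of sglang):

from typing import Dict, List, Optional, Union


def resolve_parallel_sample_num(
    sampling_params: Optional[Union[List[Dict], Dict]],
) -> int:
    if isinstance(sampling_params, dict):
        return sampling_params.get("n", 1)
    if isinstance(sampling_params, list):
        nums = [sp.get("n", 1) for sp in sampling_params]
        n = max(nums)
        # Mixed n values across one batch are rejected, as in the diff above.
        if n > 1 and any(x != n for x in nums):
            raise ValueError(
                "The parallel_sample_num should be the same for all samples."
            )
        return n
    return 1


print(resolve_parallel_sample_num([{"n": 2}, {"n": 2}]))  # 2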
sglang/srt/managers/{controller/schedule_heuristic.py → policy_scheduler.py}

@@ -1,46 +1,61 @@
-"""Request scheduler heuristic."""
+"""
+Copyright 2023-2024 SGLang Team
+[Apache License 2.0 notice, identical to the header shown in logits_processor.py above]
+"""
+
+"""Request policy scheduler"""
 
 import random
 from collections import defaultdict
 
 
-class ScheduleHeuristic:
+class PolicyScheduler:
     def __init__(
         self,
-        schedule_heuristic,
+        policy,
         max_running_seqs,
         max_prefill_num_tokens,
         max_total_num_tokens,
         tree_cache,
     ):
-        if tree_cache.disable and schedule_heuristic == "lpm":
+        if tree_cache.disable and policy == "lpm":
             # LMP is meaningless when the tree cache is disabled.
-            schedule_heuristic = "fcfs"
+            policy = "fcfs"
 
-        self.schedule_heuristic = schedule_heuristic
+        self.policy = policy
         self.max_running_seqs = max_running_seqs
         self.max_prefill_num_tokens = max_prefill_num_tokens
         self.max_total_num_tokens = max_total_num_tokens
         self.tree_cache = tree_cache
 
-    def get_priority_queue(self, forward_queue):
-        if self.schedule_heuristic == "lpm":
+    def get_priority_queue(self, waiting_queue):
+        if self.policy == "lpm":
             # longest prefix match
-            forward_queue.sort(key=lambda x: -len(x.prefix_indices))
-            return forward_queue
-        elif self.schedule_heuristic == "fcfs":
+            waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
+            return waiting_queue
+        elif self.policy == "fcfs":
             # first come first serve
-            return forward_queue
-        elif self.schedule_heuristic == "lof":
+            return waiting_queue
+        elif self.policy == "lof":
             # longest output first
-            forward_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
-            return forward_queue
-        elif self.schedule_heuristic == "random":
-            random.shuffle(forward_queue)
-            return forward_queue
-        elif self.schedule_heuristic == "dfs-weight":
+            waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
+            return waiting_queue
+        elif self.policy == "random":
+            random.shuffle(waiting_queue)
+            return waiting_queue
+        elif self.policy == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
-            for req in forward_queue:
+            for req in waiting_queue:
                 last_node_to_reqs[req.last_node].append(req)
 
             node_to_weight = defaultdict(int)
@@ -52,10 +67,10 @@ class ScheduleHeuristic:
             self.get_dfs_priority(
                 self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q
             )
-            assert len(q) == len(forward_queue)
+            assert len(q) == len(waiting_queue)
             return q
         else:
-            raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")
+            raise ValueError(f"Unknown schedule_policy: {self.policy}")
 
     def calc_weight(self, cur_node, node_to_weight):
         for child in cur_node.children.values():
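To make the renamed policies concrete, here is a toy run of the "lpm" (longest prefix match) ordering with a stand-in request type; sglang's real request objects live in schedule_batch.py, so everything below is illustrative:

from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyRequest:
    rid: str
    prefix_indices: List[int] = field(default_factory=list)  # cached-prefix slots


waiting_queue = [
    ToyRequest("a", prefix_indices=[1]),
    ToyRequest("b", prefix_indices=[1, 2, 3]),
    ToyRequest("c"),
]

# "lpm": serve the request with the longest cached prefix first, so prefill
# reuses as much of the radix cache as possible.
waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
print([r.rid for r in waiting_queue])  # ['b', 'a', 'c']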