sglang 0.3.5.post1__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. sglang/bench_latency.py +1 -553
  2. sglang/bench_offline_throughput.py +337 -0
  3. sglang/bench_one_batch.py +474 -0
  4. sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
  5. sglang/bench_serving.py +115 -31
  6. sglang/check_env.py +3 -6
  7. sglang/srt/constrained/base_grammar_backend.py +4 -3
  8. sglang/srt/constrained/outlines_backend.py +39 -26
  9. sglang/srt/constrained/xgrammar_backend.py +58 -14
  10. sglang/srt/layers/activation.py +3 -0
  11. sglang/srt/layers/attention/flashinfer_backend.py +93 -48
  12. sglang/srt/layers/attention/triton_backend.py +9 -7
  13. sglang/srt/layers/custom_op_util.py +26 -0
  14. sglang/srt/layers/fused_moe/fused_moe.py +11 -4
  15. sglang/srt/layers/fused_moe/patch.py +4 -2
  16. sglang/srt/layers/layernorm.py +4 -0
  17. sglang/srt/layers/logits_processor.py +10 -10
  18. sglang/srt/layers/sampler.py +4 -8
  19. sglang/srt/layers/torchao_utils.py +2 -0
  20. sglang/srt/managers/data_parallel_controller.py +74 -9
  21. sglang/srt/managers/detokenizer_manager.py +1 -14
  22. sglang/srt/managers/io_struct.py +27 -0
  23. sglang/srt/managers/schedule_batch.py +104 -38
  24. sglang/srt/managers/schedule_policy.py +5 -1
  25. sglang/srt/managers/scheduler.py +210 -56
  26. sglang/srt/managers/session_controller.py +62 -0
  27. sglang/srt/managers/tokenizer_manager.py +38 -0
  28. sglang/srt/managers/tp_worker.py +12 -1
  29. sglang/srt/managers/tp_worker_overlap_thread.py +49 -52
  30. sglang/srt/model_executor/cuda_graph_runner.py +43 -6
  31. sglang/srt/model_executor/forward_batch_info.py +109 -15
  32. sglang/srt/model_executor/model_runner.py +102 -43
  33. sglang/srt/model_parallel.py +98 -0
  34. sglang/srt/models/deepseek_v2.py +147 -44
  35. sglang/srt/models/gemma2.py +9 -8
  36. sglang/srt/models/llava.py +1 -1
  37. sglang/srt/models/llavavid.py +1 -1
  38. sglang/srt/models/olmo.py +3 -3
  39. sglang/srt/models/phi3_small.py +447 -0
  40. sglang/srt/models/qwen2_vl.py +13 -6
  41. sglang/srt/models/torch_native_llama.py +94 -78
  42. sglang/srt/openai_api/adapter.py +11 -4
  43. sglang/srt/openai_api/protocol.py +30 -27
  44. sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
  45. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
  46. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
  47. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
  48. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
  49. sglang/srt/sampling/sampling_batch_info.py +58 -57
  50. sglang/srt/sampling/sampling_params.py +3 -3
  51. sglang/srt/server.py +29 -2
  52. sglang/srt/server_args.py +97 -60
  53. sglang/srt/utils.py +103 -51
  54. sglang/test/runners.py +25 -6
  55. sglang/test/srt/sampling/penaltylib/utils.py +23 -21
  56. sglang/test/test_utils.py +33 -22
  57. sglang/version.py +1 -1
  58. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/METADATA +43 -43
  59. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/RECORD +62 -56
  60. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/WHEEL +1 -1
  61. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/LICENSE +0 -0
  62. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,8 @@
1
- import typing
1
+ from typing import List
2
2
 
3
3
  import torch
4
4
 
5
- from ..orchestrator import _BatchedPenalizer, _TokenIDs
5
+ from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
6
6
 
7
7
 
8
8
  class BatchedFrequencyPenalizer(_BatchedPenalizer):
@@ -44,9 +44,6 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
44
44
  )
45
45
 
46
46
  def _teardown(self):
47
- del self.frequency_penalties
48
- del self.cumulated_frequency_penalties
49
-
50
47
  self.frequency_penalties = None
51
48
  self.cumulated_frequency_penalties = None
52
49
 
@@ -62,9 +59,7 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
62
59
  logits -= self.cumulated_frequency_penalties
63
60
  return logits
64
61
 
65
- def _filter(
66
- self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
67
- ):
62
+ def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
68
63
  self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep]
69
64
  self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
70
65
  indices_tensor_to_keep
@@ -1,8 +1,8 @@
1
- import typing
1
+ from typing import List
2
2
 
3
3
  import torch
4
4
 
5
- from ..orchestrator import _BatchedPenalizer, _TokenIDs
5
+ from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
6
6
 
7
7
 
8
8
  class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
@@ -70,10 +70,6 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
70
70
  )
71
71
 
72
72
  def _teardown(self):
73
- del self.min_new_tokens
74
- del self.stop_token_penalties
75
- del self.len_output_tokens
76
-
77
73
  self.min_new_tokens = None
78
74
  self.stop_token_penalties = None
79
75
  self.len_output_tokens = None
@@ -89,9 +85,7 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
89
85
  logits[mask] += self.stop_token_penalties[mask]
90
86
  return logits
91
87
 
92
- def _filter(
93
- self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
94
- ):
88
+ def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
95
89
  self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep]
96
90
  self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep]
97
91
  self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep]
@@ -1,8 +1,8 @@
1
- import typing
1
+ from typing import List
2
2
 
3
3
  import torch
4
4
 
5
- from ..orchestrator import _BatchedPenalizer, _TokenIDs
5
+ from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
6
6
 
7
7
 
8
8
  class BatchedPresencePenalizer(_BatchedPenalizer):
@@ -44,9 +44,6 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
44
44
  )
45
45
 
46
46
  def _teardown(self):
47
- del self.presence_penalties
48
- del self.cumulated_presence_penalties
49
-
50
47
  self.presence_penalties = None
51
48
  self.cumulated_presence_penalties = None
52
49
 
@@ -61,9 +58,7 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
61
58
  logits -= self.cumulated_presence_penalties
62
59
  return logits
63
60
 
64
- def _filter(
65
- self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
66
- ):
61
+ def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
67
62
  self.presence_penalties = self.presence_penalties[indices_tensor_to_keep]
68
63
  self.cumulated_presence_penalties = self.cumulated_presence_penalties[
69
64
  indices_tensor_to_keep
@@ -1,8 +1,8 @@
1
- import typing
1
+ from typing import List
2
2
 
3
3
  import torch
4
4
 
5
- from ..orchestrator import _BatchedPenalizer, _TokenIDs
5
+ from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
6
6
 
7
7
 
8
8
  class BatchedRepetitionPenalizer(_BatchedPenalizer):
@@ -44,9 +44,6 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
44
44
  )
45
45
 
46
46
  def _teardown(self):
47
- del self.repetition_penalties
48
- del self.cumulated_repetition_penalties
49
-
50
47
  self.repetition_penalties = None
51
48
  self.cumulated_repetition_penalties = None
52
49
 
@@ -65,9 +62,7 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
65
62
  logits * self.cumulated_repetition_penalties,
66
63
  )
67
64
 
68
- def _filter(
69
- self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
70
- ):
65
+ def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
71
66
  self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
72
67
  self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
73
68
  indices_tensor_to_keep
@@ -1,12 +1,17 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import dataclasses
4
- from typing import TYPE_CHECKING, List, Optional
4
+ import logging
5
+ import threading
6
+ from typing import TYPE_CHECKING, Callable, List, Optional
5
7
 
6
8
  import torch
7
9
 
8
10
  import sglang.srt.sampling.penaltylib as penaltylib
9
11
 
12
+ logger = logging.getLogger(__name__)
13
+
14
+
10
15
  if TYPE_CHECKING:
11
16
  from sglang.srt.managers.schedule_batch import ScheduleBatch
12
17
 
@@ -27,10 +32,11 @@ class SamplingBatchInfo:
27
32
 
28
33
  # Bias Tensors
29
34
  vocab_size: int
35
+ grammars: Optional[List] = None
36
+ sampling_info_done: Optional[threading.Event] = None
30
37
  logit_bias: torch.Tensor = None
31
38
  vocab_mask: Optional[torch.Tensor] = None
32
-
33
- grammars: Optional[List] = None
39
+ apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
34
40
 
35
41
  # Penalizer
36
42
  penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
@@ -42,10 +48,7 @@ class SamplingBatchInfo:
42
48
 
43
49
  @classmethod
44
50
  def from_schedule_batch(
45
- cls,
46
- batch: ScheduleBatch,
47
- vocab_size: int,
48
- disable_penalizer: bool,
51
+ cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool
49
52
  ):
50
53
  reqs = batch.reqs
51
54
  device = batch.device
@@ -73,12 +76,39 @@ class SamplingBatchInfo:
73
76
  top_ks=top_ks,
74
77
  min_ps=min_ps,
75
78
  need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
76
- is_all_greedy=top_ks.max().item() <= 1,
79
+ is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
77
80
  vocab_size=vocab_size,
78
81
  device=device,
79
82
  )
80
83
  # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
81
84
 
85
+ if enable_overlap_schedule:
86
+ # TODO (lianmin): Some penalizers such as frequency and presence depend on model outputs,
87
+ # so it is kind of tricky to make it work with overlap scheduler.
88
+ # It requires correcly updating the penalty logits before the sampling and syncing the events.
89
+ # We will support them later.
90
+ penalizers = {
91
+ penaltylib.BatchedMinNewTokensPenalizer,
92
+ }
93
+ if (
94
+ any(req.sampling_params.frequency_penalty != 0.0 for req in reqs)
95
+ or any(req.sampling_params.presence_penalty != 0.0 for req in reqs)
96
+ or any(req.sampling_params.repetition_penalty != 1.0 for req in reqs)
97
+ ):
98
+ logger.warning(
99
+ "frequency_penalty, presence_penalty, and repetition_penalty are not supported "
100
+ "when using the default overlap scheduler. They will be ignored. "
101
+ "Please add `--disable-overlap` when launching the server if you need these features. "
102
+ "The speed will be slower in that case."
103
+ )
104
+ else:
105
+ penalizers = {
106
+ penaltylib.BatchedFrequencyPenalizer,
107
+ penaltylib.BatchedMinNewTokensPenalizer,
108
+ penaltylib.BatchedPresencePenalizer,
109
+ penaltylib.BatchedRepetitionPenalizer,
110
+ }
111
+
82
112
  # Each penalizers will do nothing if they evaluate themselves as not required by looking at
83
113
  # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
84
114
  # should not add hefty computation overhead other than simple checks.
@@ -86,20 +116,12 @@ class SamplingBatchInfo:
86
116
  # While we choose not to even create the class instances if they are not required, this
87
117
  # could add additional complexity to the {ScheduleBatch} class, especially we need to
88
118
  # handle {filter_batch()} and {merge_batch()} cases as well.
89
- if disable_penalizer:
90
- ret.penalizer_orchestrator = None
91
- else:
92
- ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
93
- vocab_size=vocab_size,
94
- batch=batch,
95
- device=batch.device,
96
- Penalizers={
97
- penaltylib.BatchedFrequencyPenalizer,
98
- penaltylib.BatchedMinNewTokensPenalizer,
99
- penaltylib.BatchedPresencePenalizer,
100
- penaltylib.BatchedRepetitionPenalizer,
101
- },
102
- )
119
+ ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
120
+ vocab_size=vocab_size,
121
+ batch=batch,
122
+ device=batch.device,
123
+ Penalizers=penalizers,
124
+ )
103
125
 
104
126
  # Handle logit bias but only allocate when needed
105
127
  ret.logit_bias = None
@@ -110,9 +132,6 @@ class SamplingBatchInfo:
110
132
  return len(self.temperatures)
111
133
 
112
134
  def update_penalties(self):
113
- if not self.penalizer_orchestrator:
114
- return
115
-
116
135
  self.scaling_penalties = None
117
136
  self.linear_penalties = None
118
137
 
@@ -133,23 +152,28 @@ class SamplingBatchInfo:
133
152
  self.linear_penalties = penalizer.apply(self.linear_penalties)
134
153
 
135
154
  def update_regex_vocab_mask(self):
136
- if not self.grammars or not any(grammar for grammar in self.grammars):
155
+ if not self.grammars:
137
156
  self.vocab_mask = None
157
+ self.apply_mask = None
138
158
  return
139
159
 
140
- self.vocab_mask = torch.zeros(
141
- len(self.temperatures),
142
- self.vocab_size,
143
- dtype=torch.bool,
160
+ # find a grammar from the list
161
+ grammar = next(grammar for grammar in self.grammars if grammar)
162
+
163
+ # maybe we can reuse the existing mask?
164
+ self.vocab_mask = grammar.allocate_vocab_mask(
165
+ vocab_size=self.vocab_size,
166
+ batch_size=len(self.temperatures),
144
167
  device=self.device,
145
168
  )
169
+ self.apply_mask = type(grammar).apply_vocab_mask # force to use static method
170
+
146
171
  for i, grammar in enumerate(self.grammars):
147
172
  if grammar is not None:
148
- grammar.fill_vocab_mask(self.vocab_mask[i])
173
+ grammar.fill_vocab_mask(self.vocab_mask, i)
149
174
 
150
175
  def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
151
- if self.penalizer_orchestrator:
152
- self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
176
+ self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
153
177
 
154
178
  for item in [
155
179
  "temperatures",
@@ -188,8 +212,7 @@ class SamplingBatchInfo:
188
212
  return None
189
213
 
190
214
  def merge_batch(self, other: "SamplingBatchInfo"):
191
- if self.penalizer_orchestrator:
192
- self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
215
+ self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
193
216
 
194
217
  for item in [
195
218
  "temperatures",
@@ -205,25 +228,3 @@ class SamplingBatchInfo:
205
228
  self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
206
229
  self.logit_bias, other.logit_bias, len(self), len(other), self.device
207
230
  )
208
-
209
- def copy(self):
210
- return SamplingBatchInfo(
211
- temperatures=self.temperatures,
212
- top_ps=self.top_ps,
213
- top_ks=self.top_ks,
214
- min_ps=self.min_ps,
215
- is_all_greedy=self.is_all_greedy,
216
- need_min_p_sampling=self.need_min_p_sampling,
217
- vocab_size=self.vocab_size,
218
- device=self.device,
219
- )
220
-
221
- def to(self, device: str):
222
- for item in [
223
- "temperatures",
224
- "top_ps",
225
- "top_ks",
226
- "min_ps",
227
- ]:
228
- value = getattr(self, item)
229
- setattr(self, item, value.to(device, non_blocking=True))
@@ -24,7 +24,6 @@ class SamplingParams:
24
24
  def __init__(
25
25
  self,
26
26
  max_new_tokens: int = 128,
27
- min_new_tokens: int = 0,
28
27
  stop: Optional[Union[str, List[str]]] = None,
29
28
  stop_token_ids: Optional[List[int]] = None,
30
29
  temperature: float = 1.0,
@@ -34,13 +33,14 @@ class SamplingParams:
34
33
  frequency_penalty: float = 0.0,
35
34
  presence_penalty: float = 0.0,
36
35
  repetition_penalty: float = 1.0,
37
- ignore_eos: bool = False,
38
- skip_special_tokens: bool = True,
36
+ min_new_tokens: int = 0,
39
37
  spaces_between_special_tokens: bool = True,
40
38
  regex: Optional[str] = None,
41
39
  n: int = 1,
42
40
  json_schema: Optional[str] = None,
43
41
  no_stop_trim: bool = False,
42
+ ignore_eos: bool = False,
43
+ skip_special_tokens: bool = True,
44
44
  ) -> None:
45
45
  self.temperature = temperature
46
46
  self.top_p = top_p
sglang/srt/server.py CHANGED
@@ -50,8 +50,10 @@ from sglang.srt.managers.data_parallel_controller import (
50
50
  )
51
51
  from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
52
52
  from sglang.srt.managers.io_struct import (
53
+ CloseSessionReqInput,
53
54
  EmbeddingReqInput,
54
55
  GenerateReqInput,
56
+ OpenSessionReqInput,
55
57
  UpdateWeightReqInput,
56
58
  )
57
59
  from sglang.srt.managers.scheduler import run_scheduler_process
@@ -139,6 +141,7 @@ async def get_model_info():
139
141
  """Get the model information."""
140
142
  result = {
141
143
  "model_path": tokenizer_manager.model_path,
144
+ "tokenizer_path": tokenizer_manager.server_args.tokenizer_path,
142
145
  "is_generation": tokenizer_manager.is_generation,
143
146
  }
144
147
  return result
@@ -214,6 +217,30 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
214
217
  )
215
218
 
216
219
 
220
+ @app.api_route("/open_session", methods=["GET", "POST"])
221
+ async def open_session(obj: OpenSessionReqInput, request: Request):
222
+ """Open a session, and return its unique session id."""
223
+ try:
224
+ session_id = await tokenizer_manager.open_session(obj, request)
225
+ return session_id
226
+ except Exception as e:
227
+ return ORJSONResponse(
228
+ {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
229
+ )
230
+
231
+
232
+ @app.api_route("/close_session", methods=["GET", "POST"])
233
+ async def close_session(obj: CloseSessionReqInput, request: Request):
234
+ """Close the session"""
235
+ try:
236
+ await tokenizer_manager.close_session(obj, request)
237
+ return Response(status_code=200)
238
+ except Exception as e:
239
+ return ORJSONResponse(
240
+ {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
241
+ )
242
+
243
+
217
244
  @time_func_latency
218
245
  async def generate_request(obj: GenerateReqInput, request: Request):
219
246
  """Handle a generate request."""
@@ -391,7 +418,7 @@ def launch_engine(
391
418
  )
392
419
  for tp_rank in tp_rank_range:
393
420
  reader, writer = mp.Pipe(duplex=False)
394
- gpu_id = tp_rank % tp_size_per_node
421
+ gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node
395
422
  proc = mp.Process(
396
423
  target=run_scheduler_process,
397
424
  args=(server_args, port_args, gpu_id, tp_rank, None, writer),
@@ -768,7 +795,7 @@ class Engine:
768
795
  self,
769
796
  # The input prompt. It can be a single prompt or a batch of prompts.
770
797
  prompt: Optional[Union[List[str], str]] = None,
771
- sampling_params: Optional[Dict] = None,
798
+ sampling_params: Optional[Union[List[Dict], Dict]] = None,
772
799
  # The token ids for text; one can either specify text or input_ids.
773
800
  input_ids: Optional[Union[List[List[int]], List[int]]] = None,
774
801
  return_logprob: Optional[Union[List[bool], bool]] = False,