sglang 0.3.5.post2__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +2 -2
  3. sglang/bench_latency.py +1 -553
  4. sglang/bench_offline_throughput.py +48 -20
  5. sglang/bench_one_batch.py +472 -0
  6. sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
  7. sglang/bench_serving.py +125 -6
  8. sglang/check_env.py +3 -6
  9. sglang/lang/backend/base_backend.py +1 -1
  10. sglang/lang/backend/runtime_endpoint.py +2 -2
  11. sglang/srt/configs/model_config.py +13 -14
  12. sglang/srt/constrained/__init__.py +13 -14
  13. sglang/srt/constrained/base_grammar_backend.py +13 -15
  14. sglang/srt/constrained/outlines_backend.py +28 -17
  15. sglang/srt/constrained/outlines_jump_forward.py +13 -15
  16. sglang/srt/constrained/xgrammar_backend.py +47 -58
  17. sglang/srt/conversation.py +13 -15
  18. sglang/srt/hf_transformers_utils.py +13 -15
  19. sglang/srt/layers/activation.py +16 -13
  20. sglang/srt/layers/attention/flashinfer_backend.py +106 -54
  21. sglang/srt/layers/attention/triton_backend.py +9 -7
  22. sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
  23. sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
  24. sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
  25. sglang/srt/layers/custom_op_util.py +25 -0
  26. sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
  27. sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +11 -4
  28. sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
  29. sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
  30. sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
  31. sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
  32. sglang/srt/layers/fused_moe_triton/layer.py +633 -0
  33. sglang/srt/layers/layernorm.py +17 -15
  34. sglang/srt/layers/logits_processor.py +23 -25
  35. sglang/srt/layers/quantization/__init__.py +77 -17
  36. sglang/srt/layers/radix_attention.py +13 -15
  37. sglang/srt/layers/rotary_embedding.py +13 -13
  38. sglang/srt/layers/sampler.py +4 -8
  39. sglang/srt/layers/torchao_utils.py +2 -0
  40. sglang/srt/lora/lora.py +13 -14
  41. sglang/srt/lora/lora_config.py +13 -14
  42. sglang/srt/lora/lora_manager.py +22 -24
  43. sglang/srt/managers/data_parallel_controller.py +98 -27
  44. sglang/srt/managers/detokenizer_manager.py +13 -15
  45. sglang/srt/managers/io_struct.py +63 -21
  46. sglang/srt/managers/schedule_batch.py +154 -59
  47. sglang/srt/managers/schedule_policy.py +18 -16
  48. sglang/srt/managers/scheduler.py +278 -109
  49. sglang/srt/managers/session_controller.py +61 -0
  50. sglang/srt/managers/tokenizer_manager.py +63 -18
  51. sglang/srt/managers/tp_worker.py +25 -16
  52. sglang/srt/managers/tp_worker_overlap_thread.py +62 -67
  53. sglang/srt/metrics/collector.py +13 -15
  54. sglang/srt/metrics/func_timer.py +13 -15
  55. sglang/srt/mm_utils.py +13 -14
  56. sglang/srt/model_executor/cuda_graph_runner.py +63 -25
  57. sglang/srt/model_executor/forward_batch_info.py +128 -32
  58. sglang/srt/model_executor/model_runner.py +132 -64
  59. sglang/srt/model_parallel.py +98 -0
  60. sglang/srt/models/chatglm.py +15 -16
  61. sglang/srt/models/commandr.py +15 -16
  62. sglang/srt/models/dbrx.py +15 -16
  63. sglang/srt/models/deepseek.py +15 -15
  64. sglang/srt/models/deepseek_v2.py +162 -59
  65. sglang/srt/models/exaone.py +14 -15
  66. sglang/srt/models/gemma.py +14 -14
  67. sglang/srt/models/gemma2.py +31 -25
  68. sglang/srt/models/gemma2_reward.py +13 -14
  69. sglang/srt/models/gpt_bigcode.py +14 -14
  70. sglang/srt/models/grok.py +15 -15
  71. sglang/srt/models/internlm2.py +13 -15
  72. sglang/srt/models/internlm2_reward.py +13 -14
  73. sglang/srt/models/llama.py +21 -21
  74. sglang/srt/models/llama_classification.py +13 -14
  75. sglang/srt/models/llama_reward.py +13 -14
  76. sglang/srt/models/llava.py +14 -16
  77. sglang/srt/models/llavavid.py +14 -16
  78. sglang/srt/models/minicpm.py +13 -15
  79. sglang/srt/models/minicpm3.py +13 -15
  80. sglang/srt/models/mistral.py +13 -15
  81. sglang/srt/models/mixtral.py +15 -15
  82. sglang/srt/models/mixtral_quant.py +14 -14
  83. sglang/srt/models/olmo.py +22 -20
  84. sglang/srt/models/olmoe.py +23 -20
  85. sglang/srt/models/phi3_small.py +447 -0
  86. sglang/srt/models/qwen.py +14 -14
  87. sglang/srt/models/qwen2.py +22 -19
  88. sglang/srt/models/qwen2_moe.py +17 -18
  89. sglang/srt/models/qwen2_vl.py +13 -6
  90. sglang/srt/models/stablelm.py +18 -16
  91. sglang/srt/models/torch_native_llama.py +107 -93
  92. sglang/srt/models/xverse.py +13 -14
  93. sglang/srt/models/xverse_moe.py +15 -16
  94. sglang/srt/models/yivl.py +13 -15
  95. sglang/srt/openai_api/adapter.py +19 -17
  96. sglang/srt/openai_api/protocol.py +14 -16
  97. sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
  98. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
  99. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
  100. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
  101. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
  102. sglang/srt/sampling/sampling_batch_info.py +61 -57
  103. sglang/srt/sampling/sampling_params.py +14 -16
  104. sglang/srt/server.py +86 -35
  105. sglang/srt/server_args.py +96 -80
  106. sglang/srt/utils.py +266 -68
  107. sglang/test/few_shot_gsm8k.py +8 -4
  108. sglang/test/runners.py +38 -20
  109. sglang/test/srt/sampling/penaltylib/utils.py +23 -21
  110. sglang/test/test_utils.py +31 -20
  111. sglang/version.py +1 -1
  112. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
  113. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +66 -57
  114. sglang-0.3.6.post1.dist-info/RECORD +164 -0
  115. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +1 -1
  116. sglang/srt/layers/fused_moe/__init__.py +0 -1
  117. sglang-0.3.5.post2.dist-info/RECORD +0 -156
  118. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
@@ -1,40 +1,34 @@
1
1
  import abc
2
2
  import dataclasses
3
- import typing
3
+ from typing import List, Set, Type, Union
4
4
 
5
5
  import torch
6
6
 
7
7
 
8
8
  @dataclasses.dataclass
9
9
  class _ReqLike:
10
- origin_input_ids: typing.Union[torch.Tensor, typing.List[int]]
10
+ origin_input_ids: List[int]
11
11
 
12
12
 
13
13
  @dataclasses.dataclass
14
14
  class _BatchLike:
15
- reqs: typing.List[_ReqLike]
15
+ reqs: List[_ReqLike]
16
16
 
17
17
  def batch_size(self):
18
18
  return len(self.reqs)
19
19
 
20
20
 
21
21
  class BatchedPenalizerOrchestrator:
22
- batch: _BatchLike
23
- device: str
24
- vocab_size: int
25
- penalizers: typing.Dict[typing.Type["_BatchedPenalizer"], "_BatchedPenalizer"]
26
-
27
22
  def __init__(
28
23
  self,
29
24
  vocab_size: int,
30
25
  batch: _BatchLike,
31
26
  device: str,
32
- Penalizers: typing.Set[typing.Type["_BatchedPenalizer"]],
27
+ Penalizers: Set[Type["_BatchedPenalizer"]],
33
28
  ):
34
29
  self.vocab_size = vocab_size
35
30
  self.batch = batch
36
31
  self.device = device
37
-
38
32
  self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}
39
33
 
40
34
  is_required = False
@@ -43,10 +37,12 @@ class BatchedPenalizerOrchestrator:
43
37
  is_required |= pen_is_required
44
38
  self.is_required = is_required
45
39
 
40
+ input_ids = [
41
+ torch.tensor(req.origin_input_ids, dtype=torch.int64, device=self.device)
42
+ for req in self.reqs()
43
+ ]
46
44
  if self.is_required:
47
- self.cumulate_input_tokens(
48
- input_ids=[req.origin_input_ids for req in self.reqs()]
49
- )
45
+ self.cumulate_input_tokens(input_ids=input_ids)
50
46
 
51
47
  def reqs(self):
52
48
  return self.batch.reqs
@@ -54,34 +50,24 @@ class BatchedPenalizerOrchestrator:
54
50
  def batch_size(self):
55
51
  return self.batch.batch_size()
56
52
 
57
- def cumulate_input_tokens(
58
- self,
59
- input_ids: typing.Union[
60
- typing.List[torch.Tensor], typing.List[typing.List[int]]
61
- ],
62
- ):
53
+ def cumulate_input_tokens(self, input_ids: List[torch.Tensor]):
63
54
  """
64
55
  Feed the input tokens to the penalizers.
65
56
 
66
57
  Args:
67
- input_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The input tokens.
58
+ input_ids (List[torch.Tensor]): The input tokens.
68
59
  """
69
60
  token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids)
70
61
 
71
62
  for penalizer in self.penalizers.values():
72
63
  penalizer.cumulate_input_tokens(input_ids=token_ids)
73
64
 
74
- def cumulate_output_tokens(
75
- self,
76
- output_ids: typing.Union[
77
- typing.List[torch.Tensor], typing.List[typing.List[int]]
78
- ],
79
- ):
65
+ def cumulate_output_tokens(self, output_ids: torch.Tensor):
80
66
  """
81
67
  Feed the output tokens to the penalizers.
82
68
 
83
69
  Args:
84
- output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens.
70
+ output_ids (torch.Tensor): The output tokens.
85
71
  """
86
72
  if not self.is_required:
87
73
  return
@@ -112,14 +98,14 @@ class BatchedPenalizerOrchestrator:
112
98
 
113
99
  def filter(
114
100
  self,
115
- indices_to_keep: typing.List[int],
101
+ indices_to_keep: List[int],
116
102
  indices_tensor_to_keep: torch.Tensor = None,
117
103
  ):
118
104
  """
119
105
  Filter the penalizers based on the indices to keep in the batch.
120
106
 
121
107
  Args:
122
- indices_to_keep (typing.List[int]): List of indices to keep in the batch.
108
+ indices_to_keep (List[int]): List of indices to keep in the batch.
123
109
  indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
124
110
  """
125
111
  if not self.is_required:
@@ -174,32 +160,18 @@ class _TokenIDs:
174
160
 
175
161
  Attributes:
176
162
  orchestrator (BatchedPenalizerOrchestrator): The orchestrator that these token IDs belong to.
177
- token_ids (typing.Union[torch.Tensor, typing.List[torch.Tensor]]): The token IDs.
163
+ token_ids (Union[torch.Tensor, List[torch.Tensor]]): The token IDs.
178
164
  cached_counts (torch.Tensor): The cached occurrence count tensor.
179
165
  """
180
166
 
181
- orchestrator: BatchedPenalizerOrchestrator
182
- token_ids: typing.Union[torch.Tensor, typing.List[torch.Tensor]]
183
- cached_counts: torch.Tensor = None
184
-
185
167
  def __init__(
186
168
  self,
187
169
  orchestrator: BatchedPenalizerOrchestrator,
188
- token_ids: typing.Union[
189
- typing.List[torch.Tensor], typing.List[typing.List[int]]
190
- ],
170
+ token_ids: Union[torch.Tensor, List[torch.Tensor]],
191
171
  ):
192
172
  self.orchestrator = orchestrator
193
-
194
- if not isinstance(token_ids[0], torch.Tensor):
195
- token_ids = [
196
- torch.tensor(
197
- data=ids, dtype=torch.int64, device=self.orchestrator.device
198
- )
199
- for ids in token_ids
200
- ]
201
-
202
173
  self.token_ids = token_ids
174
+ self.cached_counts = None
203
175
 
204
176
  def occurrence_count(self) -> torch.Tensor:
205
177
  """
@@ -213,30 +185,34 @@ class _TokenIDs:
213
185
 
214
186
  token_ids = self.token_ids
215
187
 
216
- if isinstance(token_ids, torch.Tensor):
217
- token_ids = token_ids.unsqueeze(1)
218
-
219
- # needs to be long to be used as index in scatter_add
220
- if token_ids.dtype != torch.int64:
221
- token_ids = token_ids.to(torch.int64)
222
-
223
- padded_token_ids = torch.nn.utils.rnn.pad_sequence(
224
- sequences=token_ids,
225
- batch_first=True,
226
- padding_value=self.orchestrator.vocab_size,
227
- )
228
-
229
- self.cached_counts = torch.zeros(
230
- size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
231
- dtype=torch.int64,
232
- device=self.orchestrator.device,
233
- ).scatter_add_(
234
- dim=1,
235
- index=padded_token_ids,
236
- src=torch.ones_like(padded_token_ids),
237
- )[
238
- :, : self.orchestrator.vocab_size
239
- ]
188
+ if isinstance(token_ids, list):
189
+ # TODO: optimize this part
190
+ padded_token_ids = torch.nn.utils.rnn.pad_sequence(
191
+ sequences=token_ids,
192
+ batch_first=True,
193
+ padding_value=self.orchestrator.vocab_size,
194
+ )
195
+ self.cached_counts = torch.zeros(
196
+ size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
197
+ dtype=torch.int64,
198
+ device=self.orchestrator.device,
199
+ ).scatter_add_(
200
+ dim=1,
201
+ index=padded_token_ids,
202
+ src=torch.ones_like(padded_token_ids),
203
+ )[
204
+ :, : self.orchestrator.vocab_size
205
+ ]
206
+ else:
207
+ # TODO: optimize this part. We do not need to create this big tensor every time.
208
+ # We can directly apply the results on the logits.
209
+ self.cached_counts = torch.zeros(
210
+ size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size),
211
+ device=self.orchestrator.device,
212
+ )
213
+ self.cached_counts[
214
+ torch.arange(len(token_ids), device=self.orchestrator.device), token_ids
215
+ ] = 1
240
216
 
241
217
  return self.cached_counts
242
218
 
@@ -246,11 +222,9 @@ class _BatchedPenalizer(abc.ABC):
246
222
  An abstract class for a batched penalizer.
247
223
  """
248
224
 
249
- orchestrator: BatchedPenalizerOrchestrator
250
- _is_prepared: bool = False
251
-
252
225
  def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
253
226
  self.orchestrator = orchestrator
227
+ self._is_prepared = False
254
228
 
255
229
  def is_prepared(self) -> bool:
256
230
  return self._is_prepared
@@ -293,9 +267,7 @@ class _BatchedPenalizer(abc.ABC):
293
267
 
294
268
  return self._apply(logits=logits)
295
269
 
296
- def filter(
297
- self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
298
- ):
270
+ def filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
299
271
  if not self.is_prepared():
300
272
  return
301
273
 
@@ -360,9 +332,7 @@ class _BatchedPenalizer(abc.ABC):
360
332
  pass
361
333
 
362
334
  @abc.abstractmethod
363
- def _filter(
364
- self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
365
- ):
335
+ def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
366
336
  """
367
337
  Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch.
368
338
  """
@@ -1,8 +1,8 @@
1
- import typing
1
+ from typing import List
2
2
 
3
3
  import torch
4
4
 
5
- from ..orchestrator import _BatchedPenalizer, _TokenIDs
5
+ from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
6
6
 
7
7
 
8
8
  class BatchedFrequencyPenalizer(_BatchedPenalizer):
@@ -44,9 +44,6 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
44
44
  )
45
45
 
46
46
  def _teardown(self):
47
- del self.frequency_penalties
48
- del self.cumulated_frequency_penalties
49
-
50
47
  self.frequency_penalties = None
51
48
  self.cumulated_frequency_penalties = None
52
49
 
@@ -62,9 +59,7 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
62
59
  logits -= self.cumulated_frequency_penalties
63
60
  return logits
64
61
 
65
- def _filter(
66
- self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
67
- ):
62
+ def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
68
63
  self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep]
69
64
  self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
70
65
  indices_tensor_to_keep
@@ -1,8 +1,8 @@
1
- import typing
1
+ from typing import List
2
2
 
3
3
  import torch
4
4
 
5
- from ..orchestrator import _BatchedPenalizer, _TokenIDs
5
+ from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
6
6
 
7
7
 
8
8
  class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
@@ -70,10 +70,6 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
70
70
  )
71
71
 
72
72
  def _teardown(self):
73
- del self.min_new_tokens
74
- del self.stop_token_penalties
75
- del self.len_output_tokens
76
-
77
73
  self.min_new_tokens = None
78
74
  self.stop_token_penalties = None
79
75
  self.len_output_tokens = None
@@ -89,9 +85,7 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
89
85
  logits[mask] += self.stop_token_penalties[mask]
90
86
  return logits
91
87
 
92
- def _filter(
93
- self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
94
- ):
88
+ def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
95
89
  self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep]
96
90
  self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep]
97
91
  self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep]
@@ -1,8 +1,8 @@
1
- import typing
1
+ from typing import List
2
2
 
3
3
  import torch
4
4
 
5
- from ..orchestrator import _BatchedPenalizer, _TokenIDs
5
+ from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
6
6
 
7
7
 
8
8
  class BatchedPresencePenalizer(_BatchedPenalizer):
@@ -44,9 +44,6 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
44
44
  )
45
45
 
46
46
  def _teardown(self):
47
- del self.presence_penalties
48
- del self.cumulated_presence_penalties
49
-
50
47
  self.presence_penalties = None
51
48
  self.cumulated_presence_penalties = None
52
49
 
@@ -61,9 +58,7 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
61
58
  logits -= self.cumulated_presence_penalties
62
59
  return logits
63
60
 
64
- def _filter(
65
- self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
66
- ):
61
+ def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
67
62
  self.presence_penalties = self.presence_penalties[indices_tensor_to_keep]
68
63
  self.cumulated_presence_penalties = self.cumulated_presence_penalties[
69
64
  indices_tensor_to_keep
@@ -1,8 +1,8 @@
1
- import typing
1
+ from typing import List
2
2
 
3
3
  import torch
4
4
 
5
- from ..orchestrator import _BatchedPenalizer, _TokenIDs
5
+ from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
6
6
 
7
7
 
8
8
  class BatchedRepetitionPenalizer(_BatchedPenalizer):
@@ -44,9 +44,6 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
44
44
  )
45
45
 
46
46
  def _teardown(self):
47
- del self.repetition_penalties
48
- del self.cumulated_repetition_penalties
49
-
50
47
  self.repetition_penalties = None
51
48
  self.cumulated_repetition_penalties = None
52
49
 
@@ -65,9 +62,7 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
65
62
  logits * self.cumulated_repetition_penalties,
66
63
  )
67
64
 
68
- def _filter(
69
- self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
70
- ):
65
+ def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
71
66
  self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
72
67
  self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
73
68
  indices_tensor_to_keep
@@ -1,12 +1,17 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import dataclasses
4
- from typing import TYPE_CHECKING, List, Optional
4
+ import logging
5
+ import threading
6
+ from typing import TYPE_CHECKING, Callable, List, Optional
5
7
 
6
8
  import torch
7
9
 
8
10
  import sglang.srt.sampling.penaltylib as penaltylib
9
11
 
12
+ logger = logging.getLogger(__name__)
13
+
14
+
10
15
  if TYPE_CHECKING:
11
16
  from sglang.srt.managers.schedule_batch import ScheduleBatch
12
17
 
@@ -27,10 +32,11 @@ class SamplingBatchInfo:
27
32
 
28
33
  # Bias Tensors
29
34
  vocab_size: int
35
+ grammars: Optional[List] = None
36
+ sampling_info_done: Optional[threading.Event] = None
30
37
  logit_bias: torch.Tensor = None
31
38
  vocab_mask: Optional[torch.Tensor] = None
32
-
33
- grammars: Optional[List] = None
39
+ apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
34
40
 
35
41
  # Penalizer
36
42
  penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
@@ -42,10 +48,7 @@ class SamplingBatchInfo:
42
48
 
43
49
  @classmethod
44
50
  def from_schedule_batch(
45
- cls,
46
- batch: ScheduleBatch,
47
- vocab_size: int,
48
- disable_penalizer: bool,
51
+ cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool
49
52
  ):
50
53
  reqs = batch.reqs
51
54
  device = batch.device
@@ -73,12 +76,39 @@ class SamplingBatchInfo:
73
76
  top_ks=top_ks,
74
77
  min_ps=min_ps,
75
78
  need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
76
- is_all_greedy=top_ks.max().item() <= 1,
79
+ is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
77
80
  vocab_size=vocab_size,
78
81
  device=device,
79
82
  )
80
83
  # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
81
84
 
85
+ if enable_overlap_schedule:
86
+ # TODO (lianmin): Some penalizers such as frequency and presence depend on model outputs,
87
+ # so it is kind of tricky to make it work with overlap scheduler.
88
+ # It requires correctly updating the penalty logits before the sampling and syncing the events.
89
+ # We will support them later.
90
+ penalizers = {
91
+ penaltylib.BatchedMinNewTokensPenalizer,
92
+ }
93
+ if (
94
+ any(req.sampling_params.frequency_penalty != 0.0 for req in reqs)
95
+ or any(req.sampling_params.presence_penalty != 0.0 for req in reqs)
96
+ or any(req.sampling_params.repetition_penalty != 1.0 for req in reqs)
97
+ ):
98
+ logger.warning(
99
+ "frequency_penalty, presence_penalty, and repetition_penalty are not supported "
100
+ "when using the default overlap scheduler. They will be ignored. "
101
+ "Please add `--disable-overlap` when launching the server if you need these features. "
102
+ "The speed will be slower in that case."
103
+ )
104
+ else:
105
+ penalizers = {
106
+ penaltylib.BatchedFrequencyPenalizer,
107
+ penaltylib.BatchedMinNewTokensPenalizer,
108
+ penaltylib.BatchedPresencePenalizer,
109
+ penaltylib.BatchedRepetitionPenalizer,
110
+ }
111
+
82
112
  # Each penalizer will do nothing if it evaluates itself as not required by looking at
83
113
  # the sampling_params of the requests (see {_is_required()} of each penalizer). So this
84
114
  # should not add hefty computation overhead other than simple checks.
@@ -86,20 +116,12 @@ class SamplingBatchInfo:
86
116
  # While we choose not to even create the class instances if they are not required, this
87
117
  # could add additional complexity to the {ScheduleBatch} class, especially we need to
88
118
  # handle {filter_batch()} and {merge_batch()} cases as well.
89
- if disable_penalizer:
90
- ret.penalizer_orchestrator = None
91
- else:
92
- ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
93
- vocab_size=vocab_size,
94
- batch=batch,
95
- device=batch.device,
96
- Penalizers={
97
- penaltylib.BatchedFrequencyPenalizer,
98
- penaltylib.BatchedMinNewTokensPenalizer,
99
- penaltylib.BatchedPresencePenalizer,
100
- penaltylib.BatchedRepetitionPenalizer,
101
- },
102
- )
119
+ ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
120
+ vocab_size=vocab_size,
121
+ batch=batch,
122
+ device=batch.device,
123
+ Penalizers=penalizers,
124
+ )
103
125
 
104
126
  # Handle logit bias but only allocate when needed
105
127
  ret.logit_bias = None
@@ -110,9 +132,6 @@ class SamplingBatchInfo:
110
132
  return len(self.temperatures)
111
133
 
112
134
  def update_penalties(self):
113
- if not self.penalizer_orchestrator:
114
- return
115
-
116
135
  self.scaling_penalties = None
117
136
  self.linear_penalties = None
118
137
 
@@ -133,23 +152,31 @@ class SamplingBatchInfo:
133
152
  self.linear_penalties = penalizer.apply(self.linear_penalties)
134
153
 
135
154
  def update_regex_vocab_mask(self):
136
- if not self.grammars or not any(grammar for grammar in self.grammars):
155
+ if not self.grammars:
137
156
  self.vocab_mask = None
157
+ self.apply_mask = None
138
158
  return
139
159
 
140
- self.vocab_mask = torch.zeros(
141
- len(self.temperatures),
142
- self.vocab_size,
143
- dtype=torch.bool,
160
+ # find a grammar from the list
161
+ grammar = next(grammar for grammar in self.grammars if grammar)
162
+
163
+ # maybe we can reuse the existing mask?
164
+ self.vocab_mask = grammar.allocate_vocab_mask(
165
+ vocab_size=self.vocab_size,
166
+ batch_size=len(self.temperatures),
144
167
  device=self.device,
145
168
  )
169
+ self.apply_mask = type(grammar).apply_vocab_mask # force to use static method
170
+
146
171
  for i, grammar in enumerate(self.grammars):
147
172
  if grammar is not None:
148
- grammar.fill_vocab_mask(self.vocab_mask[i])
173
+ try:
174
+ grammar.fill_vocab_mask(self.vocab_mask, i)
175
+ except RuntimeError:
176
+ continue
149
177
 
150
178
  def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
151
- if self.penalizer_orchestrator:
152
- self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
179
+ self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
153
180
 
154
181
  for item in [
155
182
  "temperatures",
@@ -188,8 +215,7 @@ class SamplingBatchInfo:
188
215
  return None
189
216
 
190
217
  def merge_batch(self, other: "SamplingBatchInfo"):
191
- if self.penalizer_orchestrator:
192
- self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
218
+ self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
193
219
 
194
220
  for item in [
195
221
  "temperatures",
@@ -205,25 +231,3 @@ class SamplingBatchInfo:
205
231
  self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
206
232
  self.logit_bias, other.logit_bias, len(self), len(other), self.device
207
233
  )
208
-
209
- def copy(self):
210
- return SamplingBatchInfo(
211
- temperatures=self.temperatures,
212
- top_ps=self.top_ps,
213
- top_ks=self.top_ks,
214
- min_ps=self.min_ps,
215
- is_all_greedy=self.is_all_greedy,
216
- need_min_p_sampling=self.need_min_p_sampling,
217
- vocab_size=self.vocab_size,
218
- device=self.device,
219
- )
220
-
221
- def to(self, device: str):
222
- for item in [
223
- "temperatures",
224
- "top_ps",
225
- "top_ks",
226
- "min_ps",
227
- ]:
228
- value = getattr(self, item)
229
- setattr(self, item, value.to(device, non_blocking=True))
@@ -1,18 +1,16 @@
1
- """
2
- Copyright 2023-2024 SGLang Team
3
- Licensed under the Apache License, Version 2.0 (the "License");
4
- you may not use this file except in compliance with the License.
5
- You may obtain a copy of the License at
6
-
7
- http://www.apache.org/licenses/LICENSE-2.0
8
-
9
- Unless required by applicable law or agreed to in writing, software
10
- distributed under the License is distributed on an "AS IS" BASIS,
11
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- See the License for the specific language governing permissions and
13
- limitations under the License.
14
- """
15
-
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
16
14
  """Sampling parameters for text generation."""
17
15
 
18
16
  from typing import List, Optional, Union
@@ -24,7 +22,6 @@ class SamplingParams:
24
22
  def __init__(
25
23
  self,
26
24
  max_new_tokens: int = 128,
27
- min_new_tokens: int = 0,
28
25
  stop: Optional[Union[str, List[str]]] = None,
29
26
  stop_token_ids: Optional[List[int]] = None,
30
27
  temperature: float = 1.0,
@@ -34,6 +31,7 @@ class SamplingParams:
34
31
  frequency_penalty: float = 0.0,
35
32
  presence_penalty: float = 0.0,
36
33
  repetition_penalty: float = 1.0,
34
+ min_new_tokens: int = 0,
37
35
  spaces_between_special_tokens: bool = True,
38
36
  regex: Optional[str] = None,
39
37
  n: int = 1,