sglang 0.3.5.post2__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/api.py +2 -2
- sglang/bench_latency.py +1 -553
- sglang/bench_offline_throughput.py +48 -20
- sglang/bench_one_batch.py +472 -0
- sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
- sglang/bench_serving.py +125 -6
- sglang/check_env.py +3 -6
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +2 -2
- sglang/srt/configs/model_config.py +13 -14
- sglang/srt/constrained/__init__.py +13 -14
- sglang/srt/constrained/base_grammar_backend.py +13 -15
- sglang/srt/constrained/outlines_backend.py +28 -17
- sglang/srt/constrained/outlines_jump_forward.py +13 -15
- sglang/srt/constrained/xgrammar_backend.py +47 -58
- sglang/srt/conversation.py +13 -15
- sglang/srt/hf_transformers_utils.py +13 -15
- sglang/srt/layers/activation.py +16 -13
- sglang/srt/layers/attention/flashinfer_backend.py +106 -54
- sglang/srt/layers/attention/triton_backend.py +9 -7
- sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
- sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
- sglang/srt/layers/custom_op_util.py +25 -0
- sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
- sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +11 -4
- sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
- sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
- sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
- sglang/srt/layers/fused_moe_triton/layer.py +633 -0
- sglang/srt/layers/layernorm.py +17 -15
- sglang/srt/layers/logits_processor.py +23 -25
- sglang/srt/layers/quantization/__init__.py +77 -17
- sglang/srt/layers/radix_attention.py +13 -15
- sglang/srt/layers/rotary_embedding.py +13 -13
- sglang/srt/layers/sampler.py +4 -8
- sglang/srt/layers/torchao_utils.py +2 -0
- sglang/srt/lora/lora.py +13 -14
- sglang/srt/lora/lora_config.py +13 -14
- sglang/srt/lora/lora_manager.py +22 -24
- sglang/srt/managers/data_parallel_controller.py +98 -27
- sglang/srt/managers/detokenizer_manager.py +13 -15
- sglang/srt/managers/io_struct.py +63 -21
- sglang/srt/managers/schedule_batch.py +154 -59
- sglang/srt/managers/schedule_policy.py +18 -16
- sglang/srt/managers/scheduler.py +278 -109
- sglang/srt/managers/session_controller.py +61 -0
- sglang/srt/managers/tokenizer_manager.py +63 -18
- sglang/srt/managers/tp_worker.py +25 -16
- sglang/srt/managers/tp_worker_overlap_thread.py +62 -67
- sglang/srt/metrics/collector.py +13 -15
- sglang/srt/metrics/func_timer.py +13 -15
- sglang/srt/mm_utils.py +13 -14
- sglang/srt/model_executor/cuda_graph_runner.py +63 -25
- sglang/srt/model_executor/forward_batch_info.py +128 -32
- sglang/srt/model_executor/model_runner.py +132 -64
- sglang/srt/model_parallel.py +98 -0
- sglang/srt/models/chatglm.py +15 -16
- sglang/srt/models/commandr.py +15 -16
- sglang/srt/models/dbrx.py +15 -16
- sglang/srt/models/deepseek.py +15 -15
- sglang/srt/models/deepseek_v2.py +162 -59
- sglang/srt/models/exaone.py +14 -15
- sglang/srt/models/gemma.py +14 -14
- sglang/srt/models/gemma2.py +31 -25
- sglang/srt/models/gemma2_reward.py +13 -14
- sglang/srt/models/gpt_bigcode.py +14 -14
- sglang/srt/models/grok.py +15 -15
- sglang/srt/models/internlm2.py +13 -15
- sglang/srt/models/internlm2_reward.py +13 -14
- sglang/srt/models/llama.py +21 -21
- sglang/srt/models/llama_classification.py +13 -14
- sglang/srt/models/llama_reward.py +13 -14
- sglang/srt/models/llava.py +14 -16
- sglang/srt/models/llavavid.py +14 -16
- sglang/srt/models/minicpm.py +13 -15
- sglang/srt/models/minicpm3.py +13 -15
- sglang/srt/models/mistral.py +13 -15
- sglang/srt/models/mixtral.py +15 -15
- sglang/srt/models/mixtral_quant.py +14 -14
- sglang/srt/models/olmo.py +22 -20
- sglang/srt/models/olmoe.py +23 -20
- sglang/srt/models/phi3_small.py +447 -0
- sglang/srt/models/qwen.py +14 -14
- sglang/srt/models/qwen2.py +22 -19
- sglang/srt/models/qwen2_moe.py +17 -18
- sglang/srt/models/qwen2_vl.py +13 -6
- sglang/srt/models/stablelm.py +18 -16
- sglang/srt/models/torch_native_llama.py +107 -93
- sglang/srt/models/xverse.py +13 -14
- sglang/srt/models/xverse_moe.py +15 -16
- sglang/srt/models/yivl.py +13 -15
- sglang/srt/openai_api/adapter.py +19 -17
- sglang/srt/openai_api/protocol.py +14 -16
- sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
- sglang/srt/sampling/sampling_batch_info.py +61 -57
- sglang/srt/sampling/sampling_params.py +14 -16
- sglang/srt/server.py +86 -35
- sglang/srt/server_args.py +96 -80
- sglang/srt/utils.py +266 -68
- sglang/test/few_shot_gsm8k.py +8 -4
- sglang/test/runners.py +38 -20
- sglang/test/srt/sampling/penaltylib/utils.py +23 -21
- sglang/test/test_utils.py +31 -20
- sglang/version.py +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +66 -57
- sglang-0.3.6.post1.dist-info/RECORD +164 -0
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +1 -1
- sglang/srt/layers/fused_moe/__init__.py +0 -1
- sglang-0.3.5.post2.dist-info/RECORD +0 -156
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
@@ -1,40 +1,34 @@
|
|
1
1
|
import abc
|
2
2
|
import dataclasses
|
3
|
-
import
|
3
|
+
from typing import List, Set, Type, Union
|
4
4
|
|
5
5
|
import torch
|
6
6
|
|
7
7
|
|
8
8
|
@dataclasses.dataclass
|
9
9
|
class _ReqLike:
|
10
|
-
origin_input_ids:
|
10
|
+
origin_input_ids: List[int]
|
11
11
|
|
12
12
|
|
13
13
|
@dataclasses.dataclass
|
14
14
|
class _BatchLike:
|
15
|
-
reqs:
|
15
|
+
reqs: List[_ReqLike]
|
16
16
|
|
17
17
|
def batch_size(self):
|
18
18
|
return len(self.reqs)
|
19
19
|
|
20
20
|
|
21
21
|
class BatchedPenalizerOrchestrator:
|
22
|
-
batch: _BatchLike
|
23
|
-
device: str
|
24
|
-
vocab_size: int
|
25
|
-
penalizers: typing.Dict[typing.Type["_BatchedPenalizer"], "_BatchedPenalizer"]
|
26
|
-
|
27
22
|
def __init__(
|
28
23
|
self,
|
29
24
|
vocab_size: int,
|
30
25
|
batch: _BatchLike,
|
31
26
|
device: str,
|
32
|
-
Penalizers:
|
27
|
+
Penalizers: Set[Type["_BatchedPenalizer"]],
|
33
28
|
):
|
34
29
|
self.vocab_size = vocab_size
|
35
30
|
self.batch = batch
|
36
31
|
self.device = device
|
37
|
-
|
38
32
|
self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}
|
39
33
|
|
40
34
|
is_required = False
|
@@ -43,10 +37,12 @@ class BatchedPenalizerOrchestrator:
|
|
43
37
|
is_required |= pen_is_required
|
44
38
|
self.is_required = is_required
|
45
39
|
|
40
|
+
input_ids = [
|
41
|
+
torch.tensor(req.origin_input_ids, dtype=torch.int64, device=self.device)
|
42
|
+
for req in self.reqs()
|
43
|
+
]
|
46
44
|
if self.is_required:
|
47
|
-
self.cumulate_input_tokens(
|
48
|
-
input_ids=[req.origin_input_ids for req in self.reqs()]
|
49
|
-
)
|
45
|
+
self.cumulate_input_tokens(input_ids=input_ids)
|
50
46
|
|
51
47
|
def reqs(self):
|
52
48
|
return self.batch.reqs
|
@@ -54,34 +50,24 @@ class BatchedPenalizerOrchestrator:
|
|
54
50
|
def batch_size(self):
|
55
51
|
return self.batch.batch_size()
|
56
52
|
|
57
|
-
def cumulate_input_tokens(
|
58
|
-
self,
|
59
|
-
input_ids: typing.Union[
|
60
|
-
typing.List[torch.Tensor], typing.List[typing.List[int]]
|
61
|
-
],
|
62
|
-
):
|
53
|
+
def cumulate_input_tokens(self, input_ids: List[torch.Tensor]):
|
63
54
|
"""
|
64
55
|
Feed the input tokens to the penalizers.
|
65
56
|
|
66
57
|
Args:
|
67
|
-
input_ids (
|
58
|
+
input_ids (List[torch.Tensor]): The input tokens.
|
68
59
|
"""
|
69
60
|
token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids)
|
70
61
|
|
71
62
|
for penalizer in self.penalizers.values():
|
72
63
|
penalizer.cumulate_input_tokens(input_ids=token_ids)
|
73
64
|
|
74
|
-
def cumulate_output_tokens(
|
75
|
-
self,
|
76
|
-
output_ids: typing.Union[
|
77
|
-
typing.List[torch.Tensor], typing.List[typing.List[int]]
|
78
|
-
],
|
79
|
-
):
|
65
|
+
def cumulate_output_tokens(self, output_ids: torch.Tensor):
|
80
66
|
"""
|
81
67
|
Feed the output tokens to the penalizers.
|
82
68
|
|
83
69
|
Args:
|
84
|
-
output_ids (
|
70
|
+
output_ids (torch.Tensor): The output tokens.
|
85
71
|
"""
|
86
72
|
if not self.is_required:
|
87
73
|
return
|
@@ -112,14 +98,14 @@ class BatchedPenalizerOrchestrator:
|
|
112
98
|
|
113
99
|
def filter(
|
114
100
|
self,
|
115
|
-
indices_to_keep:
|
101
|
+
indices_to_keep: List[int],
|
116
102
|
indices_tensor_to_keep: torch.Tensor = None,
|
117
103
|
):
|
118
104
|
"""
|
119
105
|
Filter the penalizers based on the indices to keep in the batch.
|
120
106
|
|
121
107
|
Args:
|
122
|
-
indices_to_keep (
|
108
|
+
indices_to_keep (List[int]): List of indices to keep in the batch.
|
123
109
|
indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
|
124
110
|
"""
|
125
111
|
if not self.is_required:
|
@@ -174,32 +160,18 @@ class _TokenIDs:
|
|
174
160
|
|
175
161
|
Attributes:
|
176
162
|
orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to.
|
177
|
-
token_ids (
|
163
|
+
token_ids (Union[torch.Tensor, List[torch.Tensor]]): The token IDs.
|
178
164
|
cached_counts (torch.Tensor): The cached occurrence count tensor.
|
179
165
|
"""
|
180
166
|
|
181
|
-
orchestrator: BatchedPenalizerOrchestrator
|
182
|
-
token_ids: typing.Union[torch.Tensor, typing.List[torch.Tensor]]
|
183
|
-
cached_counts: torch.Tensor = None
|
184
|
-
|
185
167
|
def __init__(
|
186
168
|
self,
|
187
169
|
orchestrator: BatchedPenalizerOrchestrator,
|
188
|
-
token_ids:
|
189
|
-
typing.List[torch.Tensor], typing.List[typing.List[int]]
|
190
|
-
],
|
170
|
+
token_ids: Union[torch.Tensor, List[torch.Tensor]],
|
191
171
|
):
|
192
172
|
self.orchestrator = orchestrator
|
193
|
-
|
194
|
-
if not isinstance(token_ids[0], torch.Tensor):
|
195
|
-
token_ids = [
|
196
|
-
torch.tensor(
|
197
|
-
data=ids, dtype=torch.int64, device=self.orchestrator.device
|
198
|
-
)
|
199
|
-
for ids in token_ids
|
200
|
-
]
|
201
|
-
|
202
173
|
self.token_ids = token_ids
|
174
|
+
self.cached_counts = None
|
203
175
|
|
204
176
|
def occurrence_count(self) -> torch.Tensor:
|
205
177
|
"""
|
@@ -213,30 +185,34 @@ class _TokenIDs:
|
|
213
185
|
|
214
186
|
token_ids = self.token_ids
|
215
187
|
|
216
|
-
if isinstance(token_ids,
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
188
|
+
if isinstance(token_ids, list):
|
189
|
+
# TODO: optimize this part
|
190
|
+
padded_token_ids = torch.nn.utils.rnn.pad_sequence(
|
191
|
+
sequences=token_ids,
|
192
|
+
batch_first=True,
|
193
|
+
padding_value=self.orchestrator.vocab_size,
|
194
|
+
)
|
195
|
+
self.cached_counts = torch.zeros(
|
196
|
+
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
|
197
|
+
dtype=torch.int64,
|
198
|
+
device=self.orchestrator.device,
|
199
|
+
).scatter_add_(
|
200
|
+
dim=1,
|
201
|
+
index=padded_token_ids,
|
202
|
+
src=torch.ones_like(padded_token_ids),
|
203
|
+
)[
|
204
|
+
:, : self.orchestrator.vocab_size
|
205
|
+
]
|
206
|
+
else:
|
207
|
+
# TODO: optimize this part. We do not need to create this big tensor every time.
|
208
|
+
# We can directly apply the results on the logits.
|
209
|
+
self.cached_counts = torch.zeros(
|
210
|
+
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size),
|
211
|
+
device=self.orchestrator.device,
|
212
|
+
)
|
213
|
+
self.cached_counts[
|
214
|
+
torch.arange(len(token_ids), device=self.orchestrator.device), token_ids
|
215
|
+
] = 1
|
240
216
|
|
241
217
|
return self.cached_counts
|
242
218
|
|
@@ -246,11 +222,9 @@ class _BatchedPenalizer(abc.ABC):
|
|
246
222
|
An abstract class for a batched penalizer.
|
247
223
|
"""
|
248
224
|
|
249
|
-
orchestrator: BatchedPenalizerOrchestrator
|
250
|
-
_is_prepared: bool = False
|
251
|
-
|
252
225
|
def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
|
253
226
|
self.orchestrator = orchestrator
|
227
|
+
self._is_prepared = False
|
254
228
|
|
255
229
|
def is_prepared(self) -> bool:
|
256
230
|
return self._is_prepared
|
@@ -293,9 +267,7 @@ class _BatchedPenalizer(abc.ABC):
|
|
293
267
|
|
294
268
|
return self._apply(logits=logits)
|
295
269
|
|
296
|
-
def filter(
|
297
|
-
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
298
|
-
):
|
270
|
+
def filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
299
271
|
if not self.is_prepared():
|
300
272
|
return
|
301
273
|
|
@@ -360,9 +332,7 @@ class _BatchedPenalizer(abc.ABC):
|
|
360
332
|
pass
|
361
333
|
|
362
334
|
@abc.abstractmethod
|
363
|
-
def _filter(
|
364
|
-
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
365
|
-
):
|
335
|
+
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
366
336
|
"""
|
367
337
|
Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch.
|
368
338
|
"""
|
@@ -1,8 +1,8 @@
|
|
1
|
-
import
|
1
|
+
from typing import List
|
2
2
|
|
3
3
|
import torch
|
4
4
|
|
5
|
-
from
|
5
|
+
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
6
6
|
|
7
7
|
|
8
8
|
class BatchedFrequencyPenalizer(_BatchedPenalizer):
|
@@ -44,9 +44,6 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
|
|
44
44
|
)
|
45
45
|
|
46
46
|
def _teardown(self):
|
47
|
-
del self.frequency_penalties
|
48
|
-
del self.cumulated_frequency_penalties
|
49
|
-
|
50
47
|
self.frequency_penalties = None
|
51
48
|
self.cumulated_frequency_penalties = None
|
52
49
|
|
@@ -62,9 +59,7 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
|
|
62
59
|
logits -= self.cumulated_frequency_penalties
|
63
60
|
return logits
|
64
61
|
|
65
|
-
def _filter(
|
66
|
-
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
67
|
-
):
|
62
|
+
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
68
63
|
self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep]
|
69
64
|
self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
|
70
65
|
indices_tensor_to_keep
|
@@ -1,8 +1,8 @@
|
|
1
|
-
import
|
1
|
+
from typing import List
|
2
2
|
|
3
3
|
import torch
|
4
4
|
|
5
|
-
from
|
5
|
+
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
6
6
|
|
7
7
|
|
8
8
|
class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
@@ -70,10 +70,6 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
|
70
70
|
)
|
71
71
|
|
72
72
|
def _teardown(self):
|
73
|
-
del self.min_new_tokens
|
74
|
-
del self.stop_token_penalties
|
75
|
-
del self.len_output_tokens
|
76
|
-
|
77
73
|
self.min_new_tokens = None
|
78
74
|
self.stop_token_penalties = None
|
79
75
|
self.len_output_tokens = None
|
@@ -89,9 +85,7 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
|
89
85
|
logits[mask] += self.stop_token_penalties[mask]
|
90
86
|
return logits
|
91
87
|
|
92
|
-
def _filter(
|
93
|
-
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
94
|
-
):
|
88
|
+
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
95
89
|
self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep]
|
96
90
|
self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep]
|
97
91
|
self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep]
|
@@ -1,8 +1,8 @@
|
|
1
|
-
import
|
1
|
+
from typing import List
|
2
2
|
|
3
3
|
import torch
|
4
4
|
|
5
|
-
from
|
5
|
+
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
6
6
|
|
7
7
|
|
8
8
|
class BatchedPresencePenalizer(_BatchedPenalizer):
|
@@ -44,9 +44,6 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
|
|
44
44
|
)
|
45
45
|
|
46
46
|
def _teardown(self):
|
47
|
-
del self.presence_penalties
|
48
|
-
del self.cumulated_presence_penalties
|
49
|
-
|
50
47
|
self.presence_penalties = None
|
51
48
|
self.cumulated_presence_penalties = None
|
52
49
|
|
@@ -61,9 +58,7 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
|
|
61
58
|
logits -= self.cumulated_presence_penalties
|
62
59
|
return logits
|
63
60
|
|
64
|
-
def _filter(
|
65
|
-
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
66
|
-
):
|
61
|
+
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
67
62
|
self.presence_penalties = self.presence_penalties[indices_tensor_to_keep]
|
68
63
|
self.cumulated_presence_penalties = self.cumulated_presence_penalties[
|
69
64
|
indices_tensor_to_keep
|
@@ -1,8 +1,8 @@
|
|
1
|
-
import
|
1
|
+
from typing import List
|
2
2
|
|
3
3
|
import torch
|
4
4
|
|
5
|
-
from
|
5
|
+
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
6
6
|
|
7
7
|
|
8
8
|
class BatchedRepetitionPenalizer(_BatchedPenalizer):
|
@@ -44,9 +44,6 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
|
|
44
44
|
)
|
45
45
|
|
46
46
|
def _teardown(self):
|
47
|
-
del self.repetition_penalties
|
48
|
-
del self.cumulated_repetition_penalties
|
49
|
-
|
50
47
|
self.repetition_penalties = None
|
51
48
|
self.cumulated_repetition_penalties = None
|
52
49
|
|
@@ -65,9 +62,7 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
|
|
65
62
|
logits * self.cumulated_repetition_penalties,
|
66
63
|
)
|
67
64
|
|
68
|
-
def _filter(
|
69
|
-
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
70
|
-
):
|
65
|
+
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
71
66
|
self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
|
72
67
|
self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
|
73
68
|
indices_tensor_to_keep
|
@@ -1,12 +1,17 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import dataclasses
|
4
|
-
|
4
|
+
import logging
|
5
|
+
import threading
|
6
|
+
from typing import TYPE_CHECKING, Callable, List, Optional
|
5
7
|
|
6
8
|
import torch
|
7
9
|
|
8
10
|
import sglang.srt.sampling.penaltylib as penaltylib
|
9
11
|
|
12
|
+
logger = logging.getLogger(__name__)
|
13
|
+
|
14
|
+
|
10
15
|
if TYPE_CHECKING:
|
11
16
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
12
17
|
|
@@ -27,10 +32,11 @@ class SamplingBatchInfo:
|
|
27
32
|
|
28
33
|
# Bias Tensors
|
29
34
|
vocab_size: int
|
35
|
+
grammars: Optional[List] = None
|
36
|
+
sampling_info_done: Optional[threading.Event] = None
|
30
37
|
logit_bias: torch.Tensor = None
|
31
38
|
vocab_mask: Optional[torch.Tensor] = None
|
32
|
-
|
33
|
-
grammars: Optional[List] = None
|
39
|
+
apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
|
34
40
|
|
35
41
|
# Penalizer
|
36
42
|
penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
|
@@ -42,10 +48,7 @@ class SamplingBatchInfo:
|
|
42
48
|
|
43
49
|
@classmethod
|
44
50
|
def from_schedule_batch(
|
45
|
-
cls,
|
46
|
-
batch: ScheduleBatch,
|
47
|
-
vocab_size: int,
|
48
|
-
disable_penalizer: bool,
|
51
|
+
cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool
|
49
52
|
):
|
50
53
|
reqs = batch.reqs
|
51
54
|
device = batch.device
|
@@ -73,12 +76,39 @@ class SamplingBatchInfo:
|
|
73
76
|
top_ks=top_ks,
|
74
77
|
min_ps=min_ps,
|
75
78
|
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
|
76
|
-
is_all_greedy=
|
79
|
+
is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
|
77
80
|
vocab_size=vocab_size,
|
78
81
|
device=device,
|
79
82
|
)
|
80
83
|
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
|
81
84
|
|
85
|
+
if enable_overlap_schedule:
|
86
|
+
# TODO (lianmin): Some penalizers such as frequency and presence depend on model outputs,
|
87
|
+
# so it is kind of tricky to make it work with overlap scheduler.
|
88
|
+
# It requires correctly updating the penalty logits before the sampling and syncing the events.
|
89
|
+
# We will support them later.
|
90
|
+
penalizers = {
|
91
|
+
penaltylib.BatchedMinNewTokensPenalizer,
|
92
|
+
}
|
93
|
+
if (
|
94
|
+
any(req.sampling_params.frequency_penalty != 0.0 for req in reqs)
|
95
|
+
or any(req.sampling_params.presence_penalty != 0.0 for req in reqs)
|
96
|
+
or any(req.sampling_params.repetition_penalty != 1.0 for req in reqs)
|
97
|
+
):
|
98
|
+
logger.warning(
|
99
|
+
"frequency_penalty, presence_penalty, and repetition_penalty are not supported "
|
100
|
+
"when using the default overlap scheduler. They will be ignored. "
|
101
|
+
"Please add `--disable-overlap` when launching the server if you need these features. "
|
102
|
+
"The speed will be slower in that case."
|
103
|
+
)
|
104
|
+
else:
|
105
|
+
penalizers = {
|
106
|
+
penaltylib.BatchedFrequencyPenalizer,
|
107
|
+
penaltylib.BatchedMinNewTokensPenalizer,
|
108
|
+
penaltylib.BatchedPresencePenalizer,
|
109
|
+
penaltylib.BatchedRepetitionPenalizer,
|
110
|
+
}
|
111
|
+
|
82
112
|
# Each penalizer will do nothing if it evaluates itself as not required by looking at
|
83
113
|
# the sampling_params of the requests (See {_is_required()} of each penalizer). So this
|
84
114
|
# should not add hefty computation overhead other than simple checks.
|
@@ -86,20 +116,12 @@ class SamplingBatchInfo:
|
|
86
116
|
# While we choose not to even create the class instances if they are not required, this
|
87
117
|
# could add additional complexity to the {ScheduleBatch} class, especially we need to
|
88
118
|
# handle {filter_batch()} and {merge_batch()} cases as well.
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
device=batch.device,
|
96
|
-
Penalizers={
|
97
|
-
penaltylib.BatchedFrequencyPenalizer,
|
98
|
-
penaltylib.BatchedMinNewTokensPenalizer,
|
99
|
-
penaltylib.BatchedPresencePenalizer,
|
100
|
-
penaltylib.BatchedRepetitionPenalizer,
|
101
|
-
},
|
102
|
-
)
|
119
|
+
ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
|
120
|
+
vocab_size=vocab_size,
|
121
|
+
batch=batch,
|
122
|
+
device=batch.device,
|
123
|
+
Penalizers=penalizers,
|
124
|
+
)
|
103
125
|
|
104
126
|
# Handle logit bias but only allocate when needed
|
105
127
|
ret.logit_bias = None
|
@@ -110,9 +132,6 @@ class SamplingBatchInfo:
|
|
110
132
|
return len(self.temperatures)
|
111
133
|
|
112
134
|
def update_penalties(self):
|
113
|
-
if not self.penalizer_orchestrator:
|
114
|
-
return
|
115
|
-
|
116
135
|
self.scaling_penalties = None
|
117
136
|
self.linear_penalties = None
|
118
137
|
|
@@ -133,23 +152,31 @@ class SamplingBatchInfo:
|
|
133
152
|
self.linear_penalties = penalizer.apply(self.linear_penalties)
|
134
153
|
|
135
154
|
def update_regex_vocab_mask(self):
|
136
|
-
if not self.grammars
|
155
|
+
if not self.grammars:
|
137
156
|
self.vocab_mask = None
|
157
|
+
self.apply_mask = None
|
138
158
|
return
|
139
159
|
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
160
|
+
# find a grammar from the list
|
161
|
+
grammar = next(grammar for grammar in self.grammars if grammar)
|
162
|
+
|
163
|
+
# maybe we can reuse the existing mask?
|
164
|
+
self.vocab_mask = grammar.allocate_vocab_mask(
|
165
|
+
vocab_size=self.vocab_size,
|
166
|
+
batch_size=len(self.temperatures),
|
144
167
|
device=self.device,
|
145
168
|
)
|
169
|
+
self.apply_mask = type(grammar).apply_vocab_mask # force to use static method
|
170
|
+
|
146
171
|
for i, grammar in enumerate(self.grammars):
|
147
172
|
if grammar is not None:
|
148
|
-
|
173
|
+
try:
|
174
|
+
grammar.fill_vocab_mask(self.vocab_mask, i)
|
175
|
+
except RuntimeError:
|
176
|
+
continue
|
149
177
|
|
150
178
|
def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
|
151
|
-
|
152
|
-
self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
|
179
|
+
self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
|
153
180
|
|
154
181
|
for item in [
|
155
182
|
"temperatures",
|
@@ -188,8 +215,7 @@ class SamplingBatchInfo:
|
|
188
215
|
return None
|
189
216
|
|
190
217
|
def merge_batch(self, other: "SamplingBatchInfo"):
|
191
|
-
|
192
|
-
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
|
218
|
+
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
|
193
219
|
|
194
220
|
for item in [
|
195
221
|
"temperatures",
|
@@ -205,25 +231,3 @@ class SamplingBatchInfo:
|
|
205
231
|
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
206
232
|
self.logit_bias, other.logit_bias, len(self), len(other), self.device
|
207
233
|
)
|
208
|
-
|
209
|
-
def copy(self):
|
210
|
-
return SamplingBatchInfo(
|
211
|
-
temperatures=self.temperatures,
|
212
|
-
top_ps=self.top_ps,
|
213
|
-
top_ks=self.top_ks,
|
214
|
-
min_ps=self.min_ps,
|
215
|
-
is_all_greedy=self.is_all_greedy,
|
216
|
-
need_min_p_sampling=self.need_min_p_sampling,
|
217
|
-
vocab_size=self.vocab_size,
|
218
|
-
device=self.device,
|
219
|
-
)
|
220
|
-
|
221
|
-
def to(self, device: str):
|
222
|
-
for item in [
|
223
|
-
"temperatures",
|
224
|
-
"top_ps",
|
225
|
-
"top_ks",
|
226
|
-
"min_ps",
|
227
|
-
]:
|
228
|
-
value = getattr(self, item)
|
229
|
-
setattr(self, item, value.to(device, non_blocking=True))
|
@@ -1,18 +1,16 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
"""
|
15
|
-
|
1
|
+
# Copyright 2023-2024 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
16
14
|
"""Sampling parameters for text generation."""
|
17
15
|
|
18
16
|
from typing import List, Optional, Union
|
@@ -24,7 +22,6 @@ class SamplingParams:
|
|
24
22
|
def __init__(
|
25
23
|
self,
|
26
24
|
max_new_tokens: int = 128,
|
27
|
-
min_new_tokens: int = 0,
|
28
25
|
stop: Optional[Union[str, List[str]]] = None,
|
29
26
|
stop_token_ids: Optional[List[int]] = None,
|
30
27
|
temperature: float = 1.0,
|
@@ -34,6 +31,7 @@ class SamplingParams:
|
|
34
31
|
frequency_penalty: float = 0.0,
|
35
32
|
presence_penalty: float = 0.0,
|
36
33
|
repetition_penalty: float = 1.0,
|
34
|
+
min_new_tokens: int = 0,
|
37
35
|
spaces_between_special_tokens: bool = True,
|
38
36
|
regex: Optional[str] = None,
|
39
37
|
n: int = 1,
|