sglang 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- sglang/__init__.py +8 -0
- sglang/api.py +10 -2
- sglang/bench_latency.py +151 -40
- sglang/bench_serving.py +46 -22
- sglang/check_env.py +24 -2
- sglang/global_config.py +0 -1
- sglang/lang/backend/base_backend.py +3 -1
- sglang/lang/backend/openai.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +46 -29
- sglang/lang/choices.py +164 -0
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +6 -13
- sglang/lang/ir.py +14 -5
- sglang/srt/constrained/base_tool_cache.py +1 -1
- sglang/srt/constrained/fsm_cache.py +12 -2
- sglang/srt/layers/activation.py +33 -0
- sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
- sglang/srt/layers/extend_attention.py +6 -1
- sglang/srt/layers/layernorm.py +65 -0
- sglang/srt/layers/logits_processor.py +6 -1
- sglang/srt/layers/pooler.py +50 -0
- sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
- sglang/srt/layers/radix_attention.py +4 -7
- sglang/srt/managers/detokenizer_manager.py +31 -9
- sglang/srt/managers/io_struct.py +63 -0
- sglang/srt/managers/policy_scheduler.py +173 -25
- sglang/srt/managers/schedule_batch.py +174 -380
- sglang/srt/managers/tokenizer_manager.py +197 -112
- sglang/srt/managers/tp_worker.py +299 -364
- sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
- sglang/srt/mem_cache/chunk_cache.py +43 -20
- sglang/srt/mem_cache/memory_pool.py +10 -15
- sglang/srt/mem_cache/radix_cache.py +74 -40
- sglang/srt/model_executor/cuda_graph_runner.py +27 -12
- sglang/srt/model_executor/forward_batch_info.py +319 -0
- sglang/srt/model_executor/model_runner.py +30 -47
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +1 -1
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -2
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +1 -1
- sglang/srt/models/internlm2.py +3 -8
- sglang/srt/models/llama2.py +5 -5
- sglang/srt/models/llama_classification.py +1 -1
- sglang/srt/models/llama_embedding.py +88 -0
- sglang/srt/models/llava.py +1 -2
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -1
- sglang/srt/models/mixtral.py +1 -1
- sglang/srt/models/mixtral_quant.py +1 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +1 -12
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +189 -39
- sglang/srt/openai_api/protocol.py +43 -1
- sglang/srt/sampling/penaltylib/__init__.py +13 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
- sglang/srt/sampling_params.py +31 -4
- sglang/srt/server.py +93 -21
- sglang/srt/server_args.py +30 -19
- sglang/srt/utils.py +31 -13
- sglang/test/run_eval.py +10 -1
- sglang/test/runners.py +63 -63
- sglang/test/simple_eval_humaneval.py +2 -8
- sglang/test/simple_eval_mgsm.py +203 -0
- sglang/test/srt/sampling/penaltylib/utils.py +337 -0
- sglang/test/test_layernorm.py +60 -0
- sglang/test/test_programs.py +4 -2
- sglang/test/test_utils.py +21 -3
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/METADATA +50 -31
- sglang-0.2.12.dist-info/RECORD +112 -0
- sglang/srt/layers/linear.py +0 -884
- sglang/srt/layers/quantization/__init__.py +0 -64
- sglang/srt/layers/quantization/fp8.py +0 -677
- sglang-0.2.10.dist-info/RECORD +0 -100
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
sglang/srt/openai_api/protocol.py
@@ -78,6 +78,10 @@ class UsageInfo(BaseModel):
     completion_tokens: Optional[int] = 0
 
 
+class StreamOptions(BaseModel):
+    include_usage: Optional[bool] = False
+
+
 class FileRequest(BaseModel):
     # https://platform.openai.com/docs/api-reference/files/create
     file: bytes  # The File object (not file name) to be uploaded
@@ -95,6 +99,12 @@ class FileResponse(BaseModel):
     purpose: str
 
 
+class FileDeleteResponse(BaseModel):
+    id: str
+    object: str = "file"
+    deleted: bool
+
+
 class BatchRequest(BaseModel):
     input_file_id: (
         str  # The ID of an uploaded file that contains requests for the new batch
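The new FileDeleteResponse models the body returned when an uploaded file is removed (e.g. by a DELETE /v1/files/{file_id} route). A minimal construction sketch, assuming pydantic v2 (on v1, use .dict() instead of .model_dump()):

from sglang.srt.openai_api.protocol import FileDeleteResponse

# "object" defaults to "file" per the schema above
resp = FileDeleteResponse(id="file-abc123", deleted=True)
print(resp.model_dump())  # {'id': 'file-abc123', 'object': 'file', 'deleted': True}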
@@ -143,6 +153,7 @@ class CompletionRequest(BaseModel):
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
+    stream_options: Optional[StreamOptions] = None
     suffix: Optional[str] = None
     temperature: Optional[float] = 1.0
     top_p: Optional[float] = 1.0
@@ -151,6 +162,9 @@ class CompletionRequest(BaseModel):
     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     regex: Optional[str] = None
     ignore_eos: Optional[bool] = False
+    min_tokens: Optional[int] = 0
+    repetition_penalty: Optional[float] = 1.0
+    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
 
 
 class CompletionResponseChoice(BaseModel):
@@ -182,7 +196,7 @@ class CompletionStreamResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[CompletionResponseStreamChoice]
-    usage: UsageInfo
+    usage: Optional[UsageInfo] = None
 
 
 class ChatCompletionMessageGenericParam(BaseModel):
@@ -241,12 +255,16 @@ class ChatCompletionRequest(BaseModel):
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
+    stream_options: Optional[StreamOptions] = None
     temperature: Optional[float] = 0.7
     top_p: Optional[float] = 1.0
     user: Optional[str] = None
 
     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     regex: Optional[str] = None
+    min_tokens: Optional[int] = 0
+    repetition_penalty: Optional[float] = 1.0
+    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
 
 
 class ChatMessage(BaseModel):
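With these fields a client can opt into a trailing usage chunk on streamed responses and pass the SRT-only sampling extras. A usage sketch, not part of the diff itself (it assumes an sglang server already listening at localhost:30000 and openai>=1.26, whose client accepts stream_options and extra_body):

import openai

client = openai.OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

stream = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Name three prime numbers."}],
    stream=True,
    stream_options={"include_usage": True},  # new in this release
    # SRT-only extras; the OpenAI service itself would ignore these.
    extra_body={"min_tokens": 8, "repetition_penalty": 1.05},
)

for chunk in stream:
    if chunk.choices:
        print(chunk.choices[0].delta.content or "", end="")
    elif chunk.usage:  # final usage-only chunk enabled by stream_options
        print("\n", chunk.usage)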
@@ -288,3 +306,27 @@ class ChatCompletionStreamResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseStreamChoice]
+    usage: Optional[UsageInfo] = None
+
+
+class EmbeddingRequest(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/embeddings/create
+    input: Union[List[int], List[List[int]], str, List[str]]
+    model: str
+    encoding_format: str = "float"
+    dimensions: int = None
+    user: Optional[str] = None
+
+
+class EmbeddingObject(BaseModel):
+    embedding: List[float]
+    index: int
+    object: str = "embedding"
+
+
+class EmbeddingResponse(BaseModel):
+    data: List[EmbeddingObject]
+    model: str
+    object: str = "list"
+    usage: Optional[UsageInfo] = None
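These models back the new embeddings surface that pairs with pooler.py and llama_embedding.py in the file list above. A client-side sketch (assumes an sglang server with an embedding-capable model at localhost:30000 and openai>=1.x):

import openai

client = openai.OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

resp = client.embeddings.create(
    model="default",
    input=["hello world"],  # str, List[str], or token-id lists per EmbeddingRequest
)
print(len(resp.data[0].embedding), resp.data[0].object)  # vector dim, "embedding"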
sglang/srt/sampling/penaltylib/__init__.py
@@ -0,0 +1,13 @@
+from .orchestrator import BatchedPenalizerOrchestrator
+from .penalizers.frequency_penalty import BatchedFrequencyPenalizer
+from .penalizers.min_new_tokens import BatchedMinNewTokensPenalizer
+from .penalizers.presence_penalty import BatchedPresencePenalizer
+from .penalizers.repetition_penalty import BatchedRepetitionPenalizer
+
+__all__ = [
+    "BatchedFrequencyPenalizer",
+    "BatchedMinNewTokensPenalizer",
+    "BatchedPresencePenalizer",
+    "BatchedRepetitionPenalizer",
+    "BatchedPenalizerOrchestrator",
+]
sglang/srt/sampling/penaltylib/orchestrator.py
@@ -0,0 +1,357 @@
+import abc
+import dataclasses
+import typing
+
+import torch
+
+
+@dataclasses.dataclass
+class _ReqLike:
+    origin_input_ids: typing.Union[torch.Tensor, typing.List[int]]
+
+
+@dataclasses.dataclass
+class _BatchLike:
+    reqs: typing.List[_ReqLike]
+
+    def batch_size(self):
+        return len(self.reqs)
+
+
+class BatchedPenalizerOrchestrator:
+    batch: _BatchLike
+    device: str
+    vocab_size: int
+    penalizers: typing.Dict[typing.Type["_BatchedPenalizer"], "_BatchedPenalizer"]
+
+    def __init__(
+        self,
+        vocab_size: int,
+        batch: _BatchLike,
+        device: str,
+        Penalizers: typing.Set[typing.Type["_BatchedPenalizer"]],
+    ):
+        self.vocab_size = vocab_size
+        self.batch = batch
+        self.device = device
+
+        self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}
+
+        for penalizer in self.penalizers.values():
+            penalizer.prepare_if_required()
+
+        self.cumulate_input_tokens(
+            input_ids=[req.origin_input_ids for req in self.reqs()]
+        )
+
+    def reqs(self):
+        return self.batch.reqs
+
+    def batch_size(self):
+        return self.batch.batch_size()
+
+    def cumulate_input_tokens(
+        self,
+        input_ids: typing.Union[
+            typing.List[torch.Tensor], typing.List[typing.List[int]]
+        ],
+    ):
+        """
+        Feed the input tokens to the penalizers.
+
+        Args:
+            input_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The input tokens.
+        """
+        token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids)
+
+        for penalizer in self.penalizers.values():
+            penalizer.cumulate_input_tokens(input_ids=token_ids)
+
+    def cumulate_output_tokens(
+        self,
+        output_ids: typing.Union[
+            typing.List[torch.Tensor], typing.List[typing.List[int]]
+        ],
+    ):
+        """
+        Feed the output tokens to the penalizers.
+
+        Args:
+            output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens.
+        """
+        token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids)
+
+        for penalizer in self.penalizers.values():
+            penalizer.cumulate_output_tokens(output_ids=token_ids)
+
+    def apply(self, logits: torch.Tensor) -> torch.Tensor:
+        """
+        Apply the penalizers to the logits.
+        Note that it may apply the penalizers in-place.
+
+        Args:
+            logits (torch.Tensor): The logits to apply the penalizers to.
+
+        Returns:
+            torch.Tensor: The logits after applying the penalizers.
+        """
+        for penalizer in self.penalizers.values():
+            logits = penalizer.apply(logits)
+
+        return logits
+
+    def filter(
+        self,
+        indices_to_keep: typing.List[int],
+        indices_tensor_to_keep: torch.Tensor = None,
+    ):
+        """
+        Filter the penalizers based on the indices to keep in the batch.
+
+        Args:
+            indices_to_keep (typing.List[int]): List of indices to keep in the batch.
+            indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
+        """
+        empty_indices = len(indices_to_keep) == 0
+
+        for penalizer in self.penalizers.values():
+            if not penalizer.is_required() or empty_indices:
+                penalizer.teardown()
+            else:
+                # create tensor index only when it's needed
+                if indices_tensor_to_keep is None:
+                    indices_tensor_to_keep = torch.tensor(
+                        indices_to_keep, dtype=torch.int32, device=self.device
+                    )
+
+                penalizer.filter(
+                    indices_to_keep=indices_to_keep,
+                    indices_tensor_to_keep=indices_tensor_to_keep,
+                )
+
+    def merge(self, their: "BatchedPenalizerOrchestrator"):
+        """
+        Merge the penalizers of another orchestrator into this one.
+
+        Note that this function **must** be called _before_ self.batch.reqs is updated (filtered).
+        Each unprepared penalizer would have to be prepared (creating tensors, etc.) first before merging.
+        This step requires the original batch.reqs, before it gets merged with other batch.reqs.
+
+        Args:
+            their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one.
+        """
+        if self.vocab_size != their.vocab_size:
+            raise ValueError(
+                f"vocab_size mismatch: {self.vocab_size} != {their.vocab_size}"
+            )
+
+        for Penalizer, their_penalizer in their.penalizers.items():
+            if Penalizer not in self.penalizers:
+                raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers")
+
+            self.penalizers[Penalizer].merge(their_penalizer)
+
+
+class _TokenIDs:
+    """
+    A class that wraps token IDs to provide additional utility functions to penalizers.
+
+    Attributes:
+        orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to.
+        token_ids (typing.Union[torch.Tensor, typing.List[torch.Tensor]]): The token IDs.
+        cached_counts (torch.Tensor): The cached occurrence count tensor.
+    """
+
+    orchestrator: BatchedPenalizerOrchestrator
+    token_ids: typing.Union[torch.Tensor, typing.List[torch.Tensor]]
+    cached_counts: torch.Tensor = None
+
+    def __init__(
+        self,
+        orchestrator: BatchedPenalizerOrchestrator,
+        token_ids: typing.Union[
+            typing.List[torch.Tensor], typing.List[typing.List[int]]
+        ],
+    ):
+        self.orchestrator = orchestrator
+
+        if not isinstance(token_ids[0], torch.Tensor):
+            token_ids = [
+                torch.tensor(
+                    data=ids, dtype=torch.int64, device=self.orchestrator.device
+                )
+                for ids in token_ids
+            ]
+
+        self.token_ids = token_ids
+
+    def occurrence_count(self) -> torch.Tensor:
+        """
+        Returns a tensor of shape (batch_size, vocab_size) where each element is the number of times the corresponding token appears in the batch.
+
+        Returns:
+            torch.Tensor: The occurrence count tensor.
+        """
+        if self.cached_counts is not None:
+            return self.cached_counts
+
+        token_ids = self.token_ids
+
+        if isinstance(token_ids, torch.Tensor):
+            token_ids = token_ids.unsqueeze(1)
+
+            # needs to be long to be used as index in scatter_add
+            if token_ids.dtype != torch.int64:
+                token_ids = token_ids.to(torch.int64)
+
+        padded_token_ids = torch.nn.utils.rnn.pad_sequence(
+            sequences=token_ids,
+            batch_first=True,
+            padding_value=self.orchestrator.vocab_size,
+        )
+
+        self.cached_counts = torch.zeros(
+            size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
+            dtype=torch.int64,
+            device=self.orchestrator.device,
+        ).scatter_add_(
+            dim=1,
+            index=padded_token_ids,
+            src=torch.ones_like(padded_token_ids),
+        )[
+            :, : self.orchestrator.vocab_size
+        ]
+
+        return self.cached_counts
+
+
+class _BatchedPenalizer(abc.ABC):
+    """
+    An abstract class for a batched penalizer.
+    """
+
+    orchestrator: BatchedPenalizerOrchestrator
+    _is_prepared: bool = False
+
+    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
+        self.orchestrator = orchestrator
+
+    def is_prepared(self) -> bool:
+        return self._is_prepared
+
+    def is_required(self) -> bool:
+        return self._is_required()
+
+    def prepare(self):
+        if not self.is_prepared():
+            self._prepare()
+            self._is_prepared = True
+
+    def prepare_if_required(self):
+        if self.is_required():
+            self.prepare()
+
+    def teardown(self):
+        if self.is_prepared():
+            self._teardown()
+            self._is_prepared = False
+
+    def cumulate_input_tokens(self, input_ids: _TokenIDs):
+        if not self.is_prepared():
+            return
+
+        self._cumulate_input_tokens(input_ids=input_ids)
+
+    def cumulate_output_tokens(self, output_ids: _TokenIDs):
+        if not self.is_prepared():
+            return
+
+        self._cumulate_output_tokens(output_ids=output_ids)
+
+    def apply(self, logits: torch.Tensor) -> torch.Tensor:
+        if not self.is_prepared():
+            return logits
+
+        return self._apply(logits=logits)
+
+    def filter(
+        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
+    ):
+        if not self.is_prepared():
+            return
+
+        self._filter(
+            indices_to_keep=indices_to_keep,
+            indices_tensor_to_keep=indices_tensor_to_keep,
+        )
+
+    def merge(self, their: "_BatchedPenalizer"):
+        if not self.is_prepared() and not their.is_prepared():
+            return
+
+        self.prepare()
+        their.prepare()
+        self._merge(their)
+
+    @abc.abstractmethod
+    def _is_required(self) -> bool:
+        """
+        Check if the penalizer is required to be prepared.
+        """
+        pass
+
+    @abc.abstractmethod
+    def _prepare(self):
+        """
+        Prepare the penalizer.
+        Usually, this is where the penalizer initializes its tensors.
+        """
+        pass
+
+    @abc.abstractmethod
+    def _teardown(self):
+        """
+        Tear down the penalizer.
+        Usually, this is where the penalizer frees its tensors.
+        """
+        pass
+
+    @abc.abstractmethod
+    def _cumulate_input_tokens(self, input_ids: _TokenIDs):
+        """
+        Cumulate the input tokens.
+        Orchestrator will call this function to feed the input tokens to the penalizer.
+        """
+        pass
+
+    @abc.abstractmethod
+    def _cumulate_output_tokens(self, output_ids: _TokenIDs):
+        """
+        Cumulate the output tokens.
+        Orchestrator will call this function to feed the output tokens to the penalizer.
+        """
+        pass
+
+    @abc.abstractmethod
+    def _apply(self, logits: torch.Tensor) -> torch.Tensor:
+        """
+        Apply the penalizer to the logits.
+        Penalizers can modify the logits in-place if needed.
+        """
+        pass
+
+    @abc.abstractmethod
+    def _filter(
+        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
+    ):
+        """
+        Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch.
+        """
+        pass
+
+    @abc.abstractmethod
+    def _merge(self, their: "_BatchedPenalizer"):
+        """
+        Merge the penalizer with another penalizer.
+        """
+        pass
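End to end, the orchestrator is built around a batch, fed output tokens as decoding proceeds, and asked to rewrite the logits. A minimal CPU sketch of that lifecycle; FakeSamplingParams, FakeReq, and FakeBatch are hypothetical stand-ins for sglang's real SamplingParams, Req, and schedule batch, which carry many more fields:

import dataclasses

import torch

from sglang.srt.sampling.penaltylib import (
    BatchedFrequencyPenalizer,
    BatchedPenalizerOrchestrator,
)


@dataclasses.dataclass
class FakeSamplingParams:  # hypothetical stand-in
    frequency_penalty: float = 0.5


@dataclasses.dataclass
class FakeReq:  # hypothetical stand-in
    origin_input_ids: list
    sampling_params: FakeSamplingParams


@dataclasses.dataclass
class FakeBatch:  # matches the _BatchLike shape the orchestrator expects
    reqs: list

    def batch_size(self):
        return len(self.reqs)


batch = FakeBatch(reqs=[FakeReq([1, 2, 2], FakeSamplingParams())])
orchestrator = BatchedPenalizerOrchestrator(
    vocab_size=8,
    batch=batch,
    device="cpu",
    Penalizers={BatchedFrequencyPenalizer},
)

logits = torch.zeros(1, 8)
orchestrator.cumulate_output_tokens(output_ids=[[2]])  # one decoded token
print(orchestrator.apply(logits))  # token id 2 now carries a -0.5 penalty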
sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py
@@ -0,0 +1,80 @@
+import typing
+
+import torch
+
+from ..orchestrator import _BatchedPenalizer, _TokenIDs
+
+
+class BatchedFrequencyPenalizer(_BatchedPenalizer):
+    """
+    Frequency penalizer penalizes tokens based on their frequency in the output.
+    """
+
+    frequency_penalties: torch.Tensor = None
+    cumulated_frequency_penalties: torch.Tensor = None
+
+    def _is_required(self) -> bool:
+        return any(
+            req.sampling_params.frequency_penalty != 0.0
+            for req in self.orchestrator.reqs()
+        )
+
+    def _prepare(self):
+        self.cumulated_frequency_penalties = (
+            torch.tensor(
+                data=[0.0 for _ in self.orchestrator.reqs()],
+                dtype=torch.float32,
+                device=self.orchestrator.device,
+            )
+            .unsqueeze_(1)
+            .repeat(1, self.orchestrator.vocab_size)
+        )
+
+        self.frequency_penalties = (
+            torch.tensor(
+                data=[
+                    req.sampling_params.frequency_penalty
+                    for req in self.orchestrator.reqs()
+                ],
+                dtype=torch.float32,
+                device=self.orchestrator.device,
+            )
+            .unsqueeze_(1)
+            .expand_as(self.cumulated_frequency_penalties)
+        )
+
+    def _teardown(self):
+        del self.frequency_penalties
+        del self.cumulated_frequency_penalties
+
+        self.frequency_penalties = None
+        self.cumulated_frequency_penalties = None
+
+    def _cumulate_input_tokens(self, input_ids: _TokenIDs):
+        pass
+
+    def _cumulate_output_tokens(self, output_ids: _TokenIDs):
+        self.cumulated_frequency_penalties += (
+            self.frequency_penalties * output_ids.occurrence_count()
+        )
+
+    def _apply(self, logits: torch.Tensor) -> torch.Tensor:
+        logits -= self.cumulated_frequency_penalties
+        return logits
+
+    def _filter(
+        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
+    ):
+        self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep]
+        self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
+            indices_tensor_to_keep
+        ]
+
+    def _merge(self, their: "BatchedFrequencyPenalizer"):
+        self.frequency_penalties = torch.cat(
+            [self.frequency_penalties, their.frequency_penalties], dim=0
+        )
+        self.cumulated_frequency_penalties = torch.cat(
+            [self.cumulated_frequency_penalties, their.cumulated_frequency_penalties],
+            dim=0,
+        )
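The update rule matches the OpenAI-style frequency penalty: each appearance of a token in the output lowers its logit by frequency_penalty, so the accumulated subtraction is frequency_penalty x occurrence_count. A standalone arithmetic check of that rule, in plain torch and independent of the classes above:

import torch

vocab_size = 6
frequency_penalty = 0.7
counts = torch.tensor([[0, 2, 1, 0, 0, 0]])  # token 1 sampled twice, token 2 once
logits = torch.zeros(1, vocab_size)

logits -= frequency_penalty * counts  # same update as _cumulate/_apply combined
print(logits)  # tensor([[ 0.0000, -1.4000, -0.7000,  0.0000,  0.0000,  0.0000]])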
sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py
@@ -0,0 +1,105 @@
+import typing
+
+import torch
+
+from ..orchestrator import _BatchedPenalizer, _TokenIDs
+
+
+class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
+    """
+    Min new tokens penalizer penalizes tokens based on the length of the output.
+    """
+
+    min_new_tokens: torch.Tensor = None
+    stop_token_penalties: torch.Tensor = None
+    len_output_tokens: torch.Tensor = None
+
+    def _is_required(self) -> bool:
+        return any(
+            req.sampling_params.min_new_tokens > 0 for req in self.orchestrator.reqs()
+        )
+
+    def _prepare(self):
+        self.min_new_tokens = torch.tensor(
+            data=[
+                req.sampling_params.min_new_tokens for req in self.orchestrator.reqs()
+            ],
+            dtype=torch.int32,
+            device=self.orchestrator.device,
+        ).unsqueeze_(1)
+
+        padded_stop_token_ids = torch.nn.utils.rnn.pad_sequence(
+            sequences=[
+                torch.tensor(
+                    data=list(
+                        req.sampling_params.stop_token_ids
+                        | {req.tokenizer.eos_token_id}
+                    ),
+                    dtype=torch.int64,
+                    device=self.orchestrator.device,
+                )
+                for req in self.orchestrator.reqs()
+            ],
+            batch_first=True,
+            padding_value=self.orchestrator.vocab_size,
+        )
+        self.stop_token_penalties = torch.zeros(
+            size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
+            dtype=torch.float32,
+            device=self.orchestrator.device,
+        ).scatter_add_(
+            dim=1,
+            index=padded_stop_token_ids,
+            src=torch.full_like(
+                input=padded_stop_token_ids,
+                dtype=torch.float32,
+                fill_value=float("-inf"),
+                device=self.orchestrator.device,
+            ),
+        )[
+            :, : self.orchestrator.vocab_size
+        ]
+
+        self.len_output_tokens = torch.zeros(
+            size=(self.orchestrator.batch_size(), 1),
+            dtype=torch.int32,
+            device=self.orchestrator.device,
+        )
+
+    def _teardown(self):
+        del self.min_new_tokens
+        del self.stop_token_penalties
+        del self.len_output_tokens
+
+        self.min_new_tokens = None
+        self.stop_token_penalties = None
+        self.len_output_tokens = None
+
+    def _cumulate_input_tokens(self, input_ids: _TokenIDs):
+        pass
+
+    def _cumulate_output_tokens(self, output_ids: _TokenIDs):
+        self.len_output_tokens += 1
+
+    def _apply(self, logits: torch.Tensor) -> torch.Tensor:
+        mask = (self.len_output_tokens < self.min_new_tokens).expand_as(logits)
+        logits[mask] += self.stop_token_penalties[mask]
+        return logits
+
+    def _filter(
+        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
+    ):
+        self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep]
+        self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep]
+        self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep]
+
+    def _merge(self, their: "BatchedMinNewTokensPenalizer"):
+        self.min_new_tokens = torch.cat(
+            [self.min_new_tokens, their.min_new_tokens], dim=0
+        )
+        self.stop_token_penalties = torch.cat(
+            [self.stop_token_penalties, their.stop_token_penalties], dim=0
+        )
+        self.len_output_tokens = torch.cat(
+            [self.len_output_tokens, their.len_output_tokens], dim=0
+        )
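The effect of _apply is that stop tokens (including EOS) are forced to -inf until the output reaches min_new_tokens, after which the mask turns off and generation can stop naturally. A standalone sketch of that masking step:

import torch

vocab_size, eos_id = 5, 4
stop_token_penalties = torch.zeros(1, vocab_size)
stop_token_penalties[0, eos_id] = float("-inf")

len_output_tokens = torch.tensor([[1]])  # one token generated so far
min_new_tokens = torch.tensor([[3]])     # require at least three

logits = torch.zeros(1, vocab_size)
mask = (len_output_tokens < min_new_tokens).expand_as(logits)
logits[mask] += stop_token_penalties[mask]
print(logits)  # tensor([[0., 0., 0., 0., -inf]]) -> EOS cannot be sampled yet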