sglang 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. sglang/api.py +7 -1
  2. sglang/bench_latency.py +9 -6
  3. sglang/bench_serving.py +46 -22
  4. sglang/global_config.py +1 -1
  5. sglang/lang/backend/runtime_endpoint.py +60 -49
  6. sglang/lang/compiler.py +2 -2
  7. sglang/lang/interpreter.py +4 -2
  8. sglang/lang/ir.py +16 -7
  9. sglang/srt/constrained/base_tool_cache.py +1 -1
  10. sglang/srt/constrained/fsm_cache.py +12 -2
  11. sglang/srt/constrained/jump_forward.py +13 -2
  12. sglang/srt/layers/activation.py +32 -0
  13. sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
  14. sglang/srt/layers/extend_attention.py +9 -2
  15. sglang/srt/layers/fused_moe/__init__.py +1 -0
  16. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  17. sglang/srt/layers/fused_moe/layer.py +587 -0
  18. sglang/srt/layers/layernorm.py +65 -0
  19. sglang/srt/layers/logits_processor.py +7 -2
  20. sglang/srt/layers/pooler.py +50 -0
  21. sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
  22. sglang/srt/layers/radix_attention.py +40 -16
  23. sglang/srt/managers/detokenizer_manager.py +31 -9
  24. sglang/srt/managers/io_struct.py +63 -0
  25. sglang/srt/managers/policy_scheduler.py +173 -25
  26. sglang/srt/managers/schedule_batch.py +115 -97
  27. sglang/srt/managers/tokenizer_manager.py +194 -112
  28. sglang/srt/managers/tp_worker.py +290 -359
  29. sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
  30. sglang/srt/mem_cache/chunk_cache.py +43 -20
  31. sglang/srt/mem_cache/memory_pool.py +2 -2
  32. sglang/srt/mem_cache/radix_cache.py +74 -40
  33. sglang/srt/model_executor/cuda_graph_runner.py +71 -25
  34. sglang/srt/model_executor/forward_batch_info.py +293 -156
  35. sglang/srt/model_executor/model_runner.py +77 -57
  36. sglang/srt/models/chatglm.py +2 -2
  37. sglang/srt/models/commandr.py +1 -1
  38. sglang/srt/models/deepseek.py +2 -2
  39. sglang/srt/models/deepseek_v2.py +7 -6
  40. sglang/srt/models/gemma.py +1 -1
  41. sglang/srt/models/gemma2.py +11 -6
  42. sglang/srt/models/grok.py +50 -396
  43. sglang/srt/models/internlm2.py +2 -7
  44. sglang/srt/models/llama2.py +4 -4
  45. sglang/srt/models/llama_embedding.py +88 -0
  46. sglang/srt/models/minicpm.py +2 -2
  47. sglang/srt/models/mixtral.py +56 -254
  48. sglang/srt/models/mixtral_quant.py +1 -4
  49. sglang/srt/models/qwen.py +2 -2
  50. sglang/srt/models/qwen2.py +2 -2
  51. sglang/srt/models/qwen2_moe.py +2 -13
  52. sglang/srt/models/stablelm.py +1 -1
  53. sglang/srt/openai_api/adapter.py +187 -48
  54. sglang/srt/openai_api/protocol.py +37 -1
  55. sglang/srt/sampling/penaltylib/__init__.py +13 -0
  56. sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
  57. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
  58. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
  59. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
  60. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
  61. sglang/srt/sampling_params.py +31 -8
  62. sglang/srt/server.py +91 -29
  63. sglang/srt/server_args.py +32 -19
  64. sglang/srt/utils.py +32 -15
  65. sglang/test/run_eval.py +10 -1
  66. sglang/test/runners.py +81 -73
  67. sglang/test/simple_eval_humaneval.py +2 -8
  68. sglang/test/simple_eval_mgsm.py +203 -0
  69. sglang/test/srt/sampling/penaltylib/utils.py +337 -0
  70. sglang/test/test_layernorm.py +60 -0
  71. sglang/test/test_programs.py +36 -7
  72. sglang/test/test_utils.py +24 -2
  73. sglang/utils.py +0 -1
  74. sglang/version.py +1 -1
  75. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/METADATA +33 -16
  76. sglang-0.2.13.dist-info/RECORD +112 -0
  77. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/WHEEL +1 -1
  78. sglang/srt/layers/linear.py +0 -884
  79. sglang/srt/layers/quantization/__init__.py +0 -64
  80. sglang/srt/layers/quantization/fp8.py +0 -677
  81. sglang/srt/model_loader/model_loader.py +0 -292
  82. sglang/srt/model_loader/utils.py +0 -275
  83. sglang-0.2.11.dist-info/RECORD +0 -102
  84. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/LICENSE +0 -0
  85. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/top_level.txt +0 -0
sglang/srt/sampling/penaltylib/orchestrator.py
@@ -0,0 +1,357 @@
+ import abc
+ import dataclasses
+ import typing
+
+ import torch
+
+
+ @dataclasses.dataclass
+ class _ReqLike:
+     origin_input_ids: typing.Union[torch.Tensor, typing.List[int]]
+
+
+ @dataclasses.dataclass
+ class _BatchLike:
+     reqs: typing.List[_ReqLike]
+
+     def batch_size(self):
+         return len(self.reqs)
+
+
+ class BatchedPenalizerOrchestrator:
+     batch: _BatchLike
+     device: str
+     vocab_size: int
+     penalizers: typing.Dict[typing.Type["_BatchedPenalizer"], "_BatchedPenalizer"]
+
+     def __init__(
+         self,
+         vocab_size: int,
+         batch: _BatchLike,
+         device: str,
+         Penalizers: typing.Set[typing.Type["_BatchedPenalizer"]],
+     ):
+         self.vocab_size = vocab_size
+         self.batch = batch
+         self.device = device
+
+         self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}
+
+         for penalizer in self.penalizers.values():
+             penalizer.prepare_if_required()
+
+         self.cumulate_input_tokens(
+             input_ids=[req.origin_input_ids for req in self.reqs()]
+         )
+
+     def reqs(self):
+         return self.batch.reqs
+
+     def batch_size(self):
+         return self.batch.batch_size()
+
+     def cumulate_input_tokens(
+         self,
+         input_ids: typing.Union[
+             typing.List[torch.Tensor], typing.List[typing.List[int]]
+         ],
+     ):
+         """
+         Feed the input tokens to the penalizers.
+
+         Args:
+             input_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The input tokens.
+         """
+         token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids)
+
+         for penalizer in self.penalizers.values():
+             penalizer.cumulate_input_tokens(input_ids=token_ids)
+
+     def cumulate_output_tokens(
+         self,
+         output_ids: typing.Union[
+             typing.List[torch.Tensor], typing.List[typing.List[int]]
+         ],
+     ):
+         """
+         Feed the output tokens to the penalizers.
+
+         Args:
+             output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens.
+         """
+         token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids)
+
+         for penalizer in self.penalizers.values():
+             penalizer.cumulate_output_tokens(output_ids=token_ids)
+
+     def apply(self, logits: torch.Tensor) -> torch.Tensor:
+         """
+         Apply the penalizers to the logits.
+         Note that it may apply the penalizers in-place.
+
+         Args:
+             logits (torch.Tensor): The logits to apply the penalizers to.
+
+         Returns:
+             torch.Tensor: The logits after applying the penalizers.
+         """
+         for penalizer in self.penalizers.values():
+             logits = penalizer.apply(logits)
+
+         return logits
+
+     def filter(
+         self,
+         indices_to_keep: typing.List[int],
+         indices_tensor_to_keep: torch.Tensor = None,
+     ):
+         """
+         Filter the penalizers based on the indices to keep in the batch.
+
+         Args:
+             indices_to_keep (typing.List[int]): List of indices to keep in the batch.
+             indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
+         """
+         empty_indices = len(indices_to_keep) == 0
+
+         for penalizer in self.penalizers.values():
+             if not penalizer.is_required() or empty_indices:
+                 penalizer.teardown()
+             else:
+                 # create tensor index only when it's needed
+                 if indices_tensor_to_keep is None:
+                     indices_tensor_to_keep = torch.tensor(
+                         indices_to_keep, dtype=torch.int32, device=self.device
+                     )
+
+                 penalizer.filter(
+                     indices_to_keep=indices_to_keep,
+                     indices_tensor_to_keep=indices_tensor_to_keep,
+                 )
+
+     def merge(self, their: "BatchedPenalizerOrchestrator"):
+         """
+         Merge the penalizers of another orchestrator into this one.
+
+         Note that this function **must** be called _before_ self.batch.reqs is updated (filtered).
+         Each unprepared penalizers would have to be prepared (creating tensors, etc.) first before merging.
+         This step requires the original batch.reqs, before it gets merged with other batch.reqs.
+
+         Args:
+             their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one.
+         """
+         if self.vocab_size != their.vocab_size:
+             raise ValueError(
+                 f"vocab_size mismatch: {self.vocab_size} != {their.vocab_size}"
+             )
+
+         for Penalizer, their_penalizer in their.penalizers.items():
+             if Penalizer not in self.penalizers:
+                 raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers")
+
+             self.penalizers[Penalizer].merge(their_penalizer)
+
+
+ class _TokenIDs:
+     """
+     A class that wraps token IDs to provide additional utility functions to penalizers.
+
+     Attributes:
+         orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to.
+         token_ids (typing.Union[torch.Tensor, typing.List[torch.Tensor]]): The token IDs.
+         cached_counts (torch.Tensor): The cached occurrence count tensor.
+     """
+
+     orchestrator: BatchedPenalizerOrchestrator
+     token_ids: typing.Union[torch.Tensor, typing.List[torch.Tensor]]
+     cached_counts: torch.Tensor = None
+
+     def __init__(
+         self,
+         orchestrator: BatchedPenalizerOrchestrator,
+         token_ids: typing.Union[
+             typing.List[torch.Tensor], typing.List[typing.List[int]]
+         ],
+     ):
+         self.orchestrator = orchestrator
+
+         if not isinstance(token_ids[0], torch.Tensor):
+             token_ids = [
+                 torch.tensor(
+                     data=ids, dtype=torch.int64, device=self.orchestrator.device
+                 )
+                 for ids in token_ids
+             ]
+
+         self.token_ids = token_ids
+
+     def occurrence_count(self) -> torch.Tensor:
+         """
+         Returns a tensor of shape (batch_size, vocab_size) where each element is the number of times the corresponding token appears in the batch.
+
+         Returns:
+             torch.Tensor: The occurrence count tensor.
+         """
+         if self.cached_counts is not None:
+             return self.cached_counts
+
+         token_ids = self.token_ids
+
+         if isinstance(token_ids, torch.Tensor):
+             token_ids = token_ids.unsqueeze(1)
+
+             # needs to be long to be used as index in scatter_add
+             if token_ids.dtype != torch.int64:
+                 token_ids = token_ids.to(torch.int64)
+
+         padded_token_ids = torch.nn.utils.rnn.pad_sequence(
+             sequences=token_ids,
+             batch_first=True,
+             padding_value=self.orchestrator.vocab_size,
+         )
+
+         self.cached_counts = torch.zeros(
+             size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
+             dtype=torch.int64,
+             device=self.orchestrator.device,
+         ).scatter_add_(
+             dim=1,
+             index=padded_token_ids,
+             src=torch.ones_like(padded_token_ids),
+         )[
+             :, : self.orchestrator.vocab_size
+         ]
+
+         return self.cached_counts
+
+
+ class _BatchedPenalizer(abc.ABC):
+     """
+     An abstract class for a batched penalizer.
+     """
+
+     orchestrator: BatchedPenalizerOrchestrator
+     _is_prepared: bool = False
+
+     def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
+         self.orchestrator = orchestrator
+
+     def is_prepared(self) -> bool:
+         return self._is_prepared
+
+     def is_required(self) -> bool:
+         return self._is_required()
+
+     def prepare(self):
+         if not self.is_prepared():
+             self._prepare()
+             self._is_prepared = True
+
+     def prepare_if_required(self):
+         if self.is_required():
+             self.prepare()
+
+     def teardown(self):
+         if self.is_prepared():
+             self._teardown()
+             self._is_prepared = False
+
+     def cumulate_input_tokens(self, input_ids: _TokenIDs):
+         if not self.is_prepared():
+             return
+
+         self._cumulate_input_tokens(input_ids=input_ids)
+
+     def cumulate_output_tokens(self, output_ids: _TokenIDs):
+         if not self.is_prepared():
+             return
+
+         self._cumulate_output_tokens(output_ids=output_ids)
+
+     def apply(self, logits: torch.Tensor) -> torch.Tensor:
+         if not self.is_prepared():
+             return logits
+
+         return self._apply(logits=logits)
+
+     def filter(
+         self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
+     ):
+         if not self.is_prepared():
+             return
+
+         self._filter(
+             indices_to_keep=indices_to_keep,
+             indices_tensor_to_keep=indices_tensor_to_keep,
+         )
+
+     def merge(self, their: "_BatchedPenalizer"):
+         if not self.is_prepared() and not their.is_prepared():
+             return
+
+         self.prepare()
+         their.prepare()
+         self._merge(their)
+
+     @abc.abstractmethod
+     def _is_required(self) -> bool:
+         """
+         Check if the penalizer is required to be prepared.
+         """
+         pass
+
+     @abc.abstractmethod
+     def _prepare(self):
+         """
+         Prepare the penalizer.
+         Usually, this is where the penalizer initializes its tensors.
+         """
+         pass
+
+     @abc.abstractmethod
+     def _teardown(self):
+         """
+         Tear down the penalizer.
+         Usually, this is where the penalizer frees its tensors.
+         """
+         pass
+
+     @abc.abstractmethod
+     def _cumulate_input_tokens(self, input_ids: _TokenIDs):
+         """
+         Cumulate the input tokens.
+         Orchestrator will call this function to feed the input tokens to the penalizer.
+         """
+         pass
+
+     @abc.abstractmethod
+     def _cumulate_output_tokens(self, output_ids: _TokenIDs):
+         """
+         Cumulate the output tokens.
+         Orchestrator will call this function to feed the output tokens to the penalizer.
+         """
+         pass
+
+     @abc.abstractmethod
+     def _apply(self, logits: torch.Tensor) -> torch.Tensor:
+         """
+         Apply the penalizer to the logits.
+         Penalizers can modify the logits in-place if needed.
+         """
+         pass
+
+     @abc.abstractmethod
+     def _filter(
+         self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
+     ):
+         """
+         Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch.
+         """
+         pass
+
+     @abc.abstractmethod
+     def _merge(self, their: "_BatchedPenalizer"):
+         """
+         Merge the penalizer with another penalizer.
+         """
+         pass
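
Editor's note: the orchestrator above owns the full penalizer lifecycle (prepare, cumulate input/output tokens, apply, filter, merge). The following sketch is illustrative only and is not part of the diff; it drives the orchestrator with SimpleNamespace stand-ins for the request and sampling-parameter objects that the real server builds elsewhere (e.g. in sglang/srt/managers/schedule_batch.py), so the request shape here is an assumption.

# Illustrative only -- not part of the diff.
import types

import torch

from sglang.srt.sampling.penaltylib.orchestrator import (
    BatchedPenalizerOrchestrator,
    _BatchLike,
)
from sglang.srt.sampling.penaltylib.penalizers.frequency_penalty import (
    BatchedFrequencyPenalizer,
)
from sglang.srt.sampling.penaltylib.penalizers.presence_penalty import (
    BatchedPresencePenalizer,
)

vocab_size = 8
reqs = [
    # Request 0 uses a frequency penalty, request 1 a presence penalty.
    types.SimpleNamespace(
        origin_input_ids=[1, 2, 2],
        sampling_params=types.SimpleNamespace(
            frequency_penalty=0.5, presence_penalty=0.0
        ),
    ),
    types.SimpleNamespace(
        origin_input_ids=[3],
        sampling_params=types.SimpleNamespace(
            frequency_penalty=0.0, presence_penalty=1.0
        ),
    ),
]

orchestrator = BatchedPenalizerOrchestrator(
    vocab_size=vocab_size,
    batch=_BatchLike(reqs=reqs),
    device="cpu",
    Penalizers={BatchedFrequencyPenalizer, BatchedPresencePenalizer},
)

# After each decoding step, feed the newly sampled token of every request ...
orchestrator.cumulate_output_tokens(output_ids=[[2], [5]])

# ... then penalize the next-step logits (they may be modified in place).
logits = torch.zeros(len(reqs), vocab_size)
penalized_logits = orchestrator.apply(logits)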
sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py
@@ -0,0 +1,80 @@
+ import typing
+
+ import torch
+
+ from ..orchestrator import _BatchedPenalizer, _TokenIDs
+
+
+ class BatchedFrequencyPenalizer(_BatchedPenalizer):
+     """
+     Frequency penalizer penalizes tokens based on their frequency in the output.
+     """
+
+     frequency_penalties: torch.Tensor = None
+     cumulated_frequency_penalties: torch.Tensor = None
+
+     def _is_required(self) -> bool:
+         return any(
+             req.sampling_params.frequency_penalty != 0.0
+             for req in self.orchestrator.reqs()
+         )
+
+     def _prepare(self):
+         self.cumulated_frequency_penalties = (
+             torch.tensor(
+                 data=[0.0 for _ in self.orchestrator.reqs()],
+                 dtype=torch.float32,
+                 device=self.orchestrator.device,
+             )
+             .unsqueeze_(1)
+             .repeat(1, self.orchestrator.vocab_size)
+         )
+
+         self.frequency_penalties = (
+             torch.tensor(
+                 data=[
+                     req.sampling_params.frequency_penalty
+                     for req in self.orchestrator.reqs()
+                 ],
+                 dtype=torch.float32,
+                 device=self.orchestrator.device,
+             )
+             .unsqueeze_(1)
+             .expand_as(self.cumulated_frequency_penalties)
+         )
+
+     def _teardown(self):
+         del self.frequency_penalties
+         del self.cumulated_frequency_penalties
+
+         self.frequency_penalties = None
+         self.cumulated_frequency_penalties = None
+
+     def _cumulate_input_tokens(self, input_ids: _TokenIDs):
+         pass
+
+     def _cumulate_output_tokens(self, output_ids: _TokenIDs):
+         self.cumulated_frequency_penalties += (
+             self.frequency_penalties * output_ids.occurrence_count()
+         )
+
+     def _apply(self, logits: torch.Tensor) -> torch.Tensor:
+         logits -= self.cumulated_frequency_penalties
+         return logits
+
+     def _filter(
+         self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
+     ):
+         self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep]
+         self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
+             indices_tensor_to_keep
+         ]
+
+     def _merge(self, their: "BatchedFrequencyPenalizer"):
+         self.frequency_penalties = torch.cat(
+             [self.frequency_penalties, their.frequency_penalties], dim=0
+         )
+         self.cumulated_frequency_penalties = torch.cat(
+             [self.cumulated_frequency_penalties, their.cumulated_frequency_penalties],
+             dim=0,
+         )
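
Editor's note: a toy illustration of the frequency-penalty arithmetic above, with made-up values (not part of the diff). Each sampled occurrence of a token subtracts one more multiple of frequency_penalty from that token's logit, so the cumulative subtraction grows linearly with its output count.

# Illustrative only -- not part of the diff.
import torch

frequency_penalty = torch.tensor([[0.5]])   # one request
counts = torch.tensor([[0, 2, 1]])          # token 1 sampled twice, token 2 once
cumulated = frequency_penalty * counts      # tensor([[0.0, 1.0, 0.5]])

logits = torch.tensor([[3.0, 3.0, 3.0]])
print(logits - cumulated)                   # tensor([[3.0000, 2.0000, 2.5000]])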
sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py
@@ -0,0 +1,105 @@
+ import typing
+
+ import torch
+
+ from ..orchestrator import _BatchedPenalizer, _TokenIDs
+
+
+ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
+     """
+     Min new tokens penalizer penalizes tokens based on the length of the output.
+     """
+
+     min_new_tokens: torch.Tensor = None
+     stop_token_penalties: torch.Tensor = None
+     len_output_tokens: torch.Tensor = None
+
+     def _is_required(self) -> bool:
+         return any(
+             req.sampling_params.min_new_tokens > 0 for req in self.orchestrator.reqs()
+         )
+
+     def _prepare(self):
+         self.min_new_tokens = torch.tensor(
+             data=[
+                 req.sampling_params.min_new_tokens for req in self.orchestrator.reqs()
+             ],
+             dtype=torch.int32,
+             device=self.orchestrator.device,
+         ).unsqueeze_(1)
+
+         padded_stop_token_ids = torch.nn.utils.rnn.pad_sequence(
+             sequences=[
+                 torch.tensor(
+                     data=list(
+                         req.sampling_params.stop_token_ids
+                         | {req.tokenizer.eos_token_id}
+                     ),
+                     dtype=torch.int64,
+                     device=self.orchestrator.device,
+                 )
+                 for req in self.orchestrator.reqs()
+             ],
+             batch_first=True,
+             padding_value=self.orchestrator.vocab_size,
+         )
+         self.stop_token_penalties = torch.zeros(
+             size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
+             dtype=torch.float32,
+             device=self.orchestrator.device,
+         ).scatter_add_(
+             dim=1,
+             index=padded_stop_token_ids,
+             src=torch.full_like(
+                 input=padded_stop_token_ids,
+                 dtype=torch.float32,
+                 fill_value=float("-inf"),
+                 device=self.orchestrator.device,
+             ),
+         )[
+             :, : self.orchestrator.vocab_size
+         ]
+
+         self.len_output_tokens = torch.zeros(
+             size=(self.orchestrator.batch_size(), 1),
+             dtype=torch.int32,
+             device=self.orchestrator.device,
+         )
+
+     def _teardown(self):
+         del self.min_new_tokens
+         del self.stop_token_penalties
+         del self.len_output_tokens
+
+         self.min_new_tokens = None
+         self.stop_token_penalties = None
+         self.len_output_tokens = None
+
+     def _cumulate_input_tokens(self, input_ids: _TokenIDs):
+         pass
+
+     def _cumulate_output_tokens(self, output_ids: _TokenIDs):
+         self.len_output_tokens += 1
+
+     def _apply(self, logits: torch.Tensor) -> torch.Tensor:
+         mask = (self.len_output_tokens < self.min_new_tokens).expand_as(logits)
+         logits[mask] += self.stop_token_penalties[mask]
+         return logits
+
+     def _filter(
+         self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
+     ):
+         self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep]
+         self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep]
+         self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep]
+
+     def _merge(self, their: "BatchedMinNewTokensPenalizer"):
+         self.min_new_tokens = torch.cat(
+             [self.min_new_tokens, their.min_new_tokens], dim=0
+         )
+         self.stop_token_penalties = torch.cat(
+             [self.stop_token_penalties, their.stop_token_penalties], dim=0
+         )
+         self.len_output_tokens = torch.cat(
+             [self.len_output_tokens, their.len_output_tokens], dim=0
+         )
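
Editor's note: a toy check of the min-new-tokens masking rule above, with assumed shapes and values (not part of the diff). Until a request has produced min_new_tokens outputs, its stop/EOS token logits are pushed to -inf so generation cannot terminate early.

# Illustrative only -- not part of the diff.
import torch

vocab_size = 4
stop_token_penalties = torch.tensor([[0.0, 0.0, 0.0, float("-inf")]])  # token 3 is EOS
min_new_tokens = torch.tensor([[2]])
len_output_tokens = torch.tensor([[1]])     # only one token generated so far

logits = torch.zeros(1, vocab_size)
mask = (len_output_tokens < min_new_tokens).expand_as(logits)
logits[mask] += stop_token_penalties[mask]
print(logits)                               # tensor([[0., 0., 0., -inf]])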
sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py
@@ -0,0 +1,79 @@
+ import typing
+
+ import torch
+
+ from ..orchestrator import _BatchedPenalizer, _TokenIDs
+
+
+ class BatchedPresencePenalizer(_BatchedPenalizer):
+     """
+     Presence penalizer penalizes tokens based on their presence in the output.
+     """
+
+     presence_penalties: torch.Tensor = None
+     cumulated_presence_penalties: torch.Tensor = None
+
+     def _is_required(self) -> bool:
+         return any(
+             req.sampling_params.presence_penalty != 0.0
+             for req in self.orchestrator.reqs()
+         )
+
+     def _prepare(self):
+         self.cumulated_presence_penalties = (
+             torch.tensor(
+                 data=[0.0 for _ in self.orchestrator.reqs()],
+                 dtype=torch.float32,
+                 device=self.orchestrator.device,
+             )
+             .unsqueeze_(1)
+             .repeat(1, self.orchestrator.vocab_size)
+         )
+
+         self.presence_penalties = (
+             torch.tensor(
+                 data=[
+                     req.sampling_params.presence_penalty
+                     for req in self.orchestrator.reqs()
+                 ],
+                 dtype=torch.float32,
+                 device=self.orchestrator.device,
+             )
+             .unsqueeze_(1)
+             .expand_as(self.cumulated_presence_penalties)
+         )
+
+     def _teardown(self):
+         del self.presence_penalties
+         del self.cumulated_presence_penalties
+
+         self.presence_penalties = None
+         self.cumulated_presence_penalties = None
+
+     def _cumulate_input_tokens(self, input_ids: _TokenIDs):
+         pass
+
+     def _cumulate_output_tokens(self, output_ids: _TokenIDs):
+         mask = output_ids.occurrence_count() > 0
+         self.cumulated_presence_penalties[mask] = self.presence_penalties[mask]
+
+     def _apply(self, logits: torch.Tensor) -> torch.Tensor:
+         logits -= self.cumulated_presence_penalties
+         return logits
+
+     def _filter(
+         self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
+     ):
+         self.presence_penalties = self.presence_penalties[indices_tensor_to_keep]
+         self.cumulated_presence_penalties = self.cumulated_presence_penalties[
+             indices_tensor_to_keep
+         ]
+
+     def _merge(self, their: "BatchedPresencePenalizer"):
+         self.presence_penalties = torch.cat(
+             [self.presence_penalties, their.presence_penalties], dim=0
+         )
+         self.cumulated_presence_penalties = torch.cat(
+             [self.cumulated_presence_penalties, their.cumulated_presence_penalties],
+             dim=0,
+         )
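
Editor's note: for contrast with the frequency penalizer, a toy illustration with made-up values (not part of the diff). The presence penalty is subtracted once for any token that has appeared at least once in the output, regardless of how many times it appeared.

# Illustrative only -- not part of the diff.
import torch

presence_penalty = torch.tensor([[1.0, 1.0, 1.0]])
counts = torch.tensor([[0, 2, 1]])          # token 1 appeared twice, token 2 once
cumulated = torch.zeros_like(presence_penalty)
mask = counts > 0
cumulated[mask] = presence_penalty[mask]    # tensor([[0., 1., 1.]])

logits = torch.tensor([[3.0, 3.0, 3.0]])
print(logits - cumulated)                   # tensor([[3., 2., 2.]])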
sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py
@@ -0,0 +1,83 @@
+ import typing
+
+ import torch
+
+ from ..orchestrator import _BatchedPenalizer, _TokenIDs
+
+
+ class BatchedRepetitionPenalizer(_BatchedPenalizer):
+     """
+     Repetition penalizer penalizes tokens based on their repetition in the input and output.
+     """
+
+     repetition_penalties: torch.Tensor = None
+     cumulated_repetition_penalties: torch.Tensor = None
+
+     def _is_required(self) -> bool:
+         return any(
+             req.sampling_params.repetition_penalty != 1.0
+             for req in self.orchestrator.reqs()
+         )
+
+     def _prepare(self):
+         self.cumulated_repetition_penalties = (
+             torch.tensor(
+                 data=[1.0 for _ in self.orchestrator.reqs()],
+                 dtype=torch.float32,
+                 device=self.orchestrator.device,
+             )
+             .unsqueeze_(1)
+             .repeat(1, self.orchestrator.vocab_size)
+         )
+
+         self.repetition_penalties = (
+             torch.tensor(
+                 data=[
+                     req.sampling_params.repetition_penalty
+                     for req in self.orchestrator.reqs()
+                 ],
+                 dtype=torch.float32,
+                 device=self.orchestrator.device,
+             )
+             .unsqueeze_(1)
+             .expand_as(self.cumulated_repetition_penalties)
+         )
+
+     def _teardown(self):
+         del self.repetition_penalties
+         del self.cumulated_repetition_penalties
+
+         self.repetition_penalties = None
+         self.cumulated_repetition_penalties = None
+
+     def _cumulate_input_tokens(self, input_ids: _TokenIDs):
+         mask = input_ids.occurrence_count() > 0
+         self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
+
+     def _cumulate_output_tokens(self, output_ids: _TokenIDs):
+         mask = output_ids.occurrence_count() > 0
+         self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
+
+     def _apply(self, logits: torch.Tensor) -> torch.Tensor:
+         return torch.where(
+             logits > 0,
+             logits / self.cumulated_repetition_penalties,
+             logits * self.cumulated_repetition_penalties,
+         )
+
+     def _filter(
+         self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
+     ):
+         self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
+         self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
+             indices_tensor_to_keep
+         ]
+
+     def _merge(self, their: "BatchedRepetitionPenalizer"):
+         self.repetition_penalties = torch.cat(
+             [self.repetition_penalties, their.repetition_penalties], dim=0
+         )
+         self.cumulated_repetition_penalties = torch.cat(
+             [self.cumulated_repetition_penalties, their.cumulated_repetition_penalties],
+             dim=0,
+         )
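
Editor's note: a toy illustration of the repetition-penalty rule above, with made-up values (not part of the diff). Unlike the additive penalties, seen tokens are rescaled: positive logits are divided by the penalty and negative logits multiplied by it, so both move toward lower probability.

# Illustrative only -- not part of the diff.
import torch

penalty = torch.tensor([[2.0, 2.0]])
logits = torch.tensor([[1.5, -1.5]])
print(torch.where(logits > 0, logits / penalty, logits * penalty))
# tensor([[ 0.7500, -3.0000]])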