divergent_beamsearch-0.1.1-py3-none-any.whl → divergent_beamsearch-0.1.3-py3-none-any.whl

@@ -1,132 +1,205 @@
- import math
- import torch
- from transformers import GPT2LMHeadModel
- from multi_choices_parser import MultiChoicesParser, end_symb
-
-
- def get_parsers_tokens(parsers : list[MultiChoicesParser]) -> tuple[list, list[int]]:
-     parsers_tokens = []
-     can_end = []
-     for parser in parsers:
-         tokens = list(parser.next())
-         if end_symb in tokens:
-             can_end.append(True)
-             tokens.remove(end_symb)
-         else:
-             can_end.append(False)
-         parsers_tokens.append(tokens)
-     return parsers_tokens, can_end
-
- def apply_mask_tokens(pred : torch.Tensor, parsers_tokens):
-     mask = torch.ones_like(pred, dtype=torch.bool)
-     for tokens in parsers_tokens:
-         mask[:, tokens] = False
-     pred[mask] = -float('inf')
-     return pred[~pred.isinf().all(dim=-1)]
-
-
- def batched_inference_logits(model : GPT2LMHeadModel, input_ids : torch.Tensor, batch_size : int = 32) -> torch.Tensor:
-     logits = []
-     for i in range(0, input_ids.shape[0], batch_size):
-         logits.append(model(input_ids[i:i+batch_size]).logits)
-     return torch.cat(logits, dim=0)
-
- def select_mask(source : list, mask : list[bool]) -> list:
-     assert len(source) == len(mask)
-     return [x for x, m in zip(source, mask) if m]
-
-
- def log1mexp(x: torch.Tensor) -> torch.Tensor:
-     """Numerically accurate evaluation of log(1 - exp(x)) for x < 0.
-     See [Maechler2012accurate]_ for details.
-     """
-     mask = -math.log(2) < x  # x < 0
-     return torch.where(
-         mask,
-         (-x.expm1()).log(),
-         (-x.exp()).log1p(),
-     )
-
- class AcceptEverythingParser:
-     def __init__(self, vocab_size : int):
-         self.vocab_size = vocab_size
-         self.tokens = tuple(range(vocab_size))
-
-     def step(self, token):
-         pass
-
-     def next(self):
-         return self.tokens
-
-     def copy(self):
-         return self
-
- @torch.no_grad()
- def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int, max_length : int, multi_choices_parser : MultiChoicesParser, pad_token_id : int, batch_size=32, num_solutions = None) -> tuple[torch.Tensor, torch.Tensor]:
-     assert input_ids.shape[0] == 1, "Batch size must be 1"
-
-     if num_solutions is None:
-         num_solutions = beam_size
-     vanilla = multi_choices_parser is None
-     if vanilla:
-         multi_choices_parser = AcceptEverythingParser(model.config.vocab_size)
-
-     parsers_unfinished = [multi_choices_parser]
-     scores_finished = torch.tensor([], dtype=torch.float)
-     solutions_finished = torch.tensor([], dtype=torch.long).view(0,0)
-
-     input_ids_unfinished = input_ids
-     scores_unfinished = torch.tensor([0.0], dtype=torch.float)
-     solutions_unfinished = torch.tensor([], dtype=torch.long).view(1,0)
-
-
-     for _ in range(max_length):
-         if len(input_ids_unfinished) == 0:
-             break
-         pred = batched_inference_logits(model, input_ids_unfinished, batch_size)[:, -1].cpu()
-         parsers_tokens, can_end = get_parsers_tokens(parsers_unfinished)
-         # input_ids_unfinished = input_ids_unfinished[~torch.tensor(can_only_end)]
-         logprobs = torch.log_softmax(pred, dim=-1)
-         logprobs_filtered = apply_mask_tokens(logprobs, parsers_tokens)
-         if len(logprobs_filtered):
-             topk = torch.topk(logprobs_filtered, beam_size, dim=-1) # shape (batch_size, beam_size)
-             values = topk.values + scores_unfinished.unsqueeze(-1)
-             topk_global = values.flatten().topk(beam_size)
-             best_tokens_row = topk_global.indices // beam_size
-             best_tokens, best_tokens_logprobs = topk.indices[best_tokens_row, topk_global.indices % beam_size], topk.values[best_tokens_row, topk_global.indices % beam_size]
-             notinf = ~best_tokens_logprobs.isinf()
-             best_tokens, best_tokens_row, best_tokens_logprobs = best_tokens[notinf], best_tokens_row[notinf], best_tokens_logprobs[notinf]
-         else:
-             best_tokens = torch.tensor([], dtype=torch.long)
-             best_tokens_row = torch.tensor([], dtype=torch.long)
-             best_tokens_logprobs = torch.tensor([], dtype=torch.float)
-
-
-         scores_finished_current = scores_unfinished[can_end]
-         solutions_finished_current = solutions_unfinished[can_end]
-         scores_finished_current = scores_finished_current + log1mexp(logprobs[can_end, select_mask(parsers_tokens, can_end)].logsumexp(dim=-1)).squeeze(-1)
-         scores_finished = torch.cat([scores_finished, scores_finished_current])
-         if len(solutions_finished_current):
-             pad = torch.full((len(scores_finished_current), solutions_finished_current.shape[1] - solutions_finished.shape[1]), pad_token_id, dtype=torch.long)
-             solutions_finished = torch.cat([solutions_finished.view(-1, solutions_finished_current.shape[1]+pad.shape[1]), torch.cat([solutions_finished_current, pad], dim=1)], dim=0)
-         if solutions_finished.numel():
-             # Keep num_solutions best solutions in finished
-             order = scores_finished.argsort(descending=True)
-             solutions_finished = solutions_finished[order][:num_solutions]
-             scores_finished = scores_finished[order][:num_solutions]
-
-
-         input_ids_unfinished = torch.cat([input_ids_unfinished[best_tokens_row], best_tokens.unsqueeze(-1)], dim=-1)
-         scores_unfinished = scores_unfinished[best_tokens_row] + best_tokens_logprobs
-         solutions_unfinished = torch.cat([solutions_unfinished[best_tokens_row], best_tokens.unsqueeze(-1)], dim=-1)
-         parsers_unfinished = [parsers_unfinished[row].copy() for row in best_tokens_row]
-         for parser, token in zip(parsers_unfinished, best_tokens.tolist()):
-             parser.step(token)
-
-     # Special case of vanilla beam search where all answers are valid
-     if vanilla:
-         order = scores_unfinished.argsort(descending=True)
-         scores_finished = scores_unfinished[order][:num_solutions]
-         solutions_finished = solutions_unfinished[order][:num_solutions]
-
-     return scores_finished, solutions_finished
+ import math
+ import torch
+ from transformers import GPT2LMHeadModel
+ from multi_choices_parser import MultiChoicesParser, end_symb
+
+
+ class Parser:
+     def step(self, token):
+         raise NotImplementedError
+
+     def next(self):
+         raise NotImplementedError
+
+     def copy(self):
+         raise NotImplementedError
+
+ def get_parsers_tokens(parsers : list[Parser]) -> tuple[list, list[int]]:
+     parsers_tokens = []
+     can_end = []
+     for parser in parsers:
+         tokens = list(parser.next())
+         if end_symb in tokens:
+             can_end.append(True)
+             tokens.remove(end_symb)
+         else:
+             can_end.append(False)
+         parsers_tokens.append(tokens)
+     return parsers_tokens, can_end
+
+ def apply_mask_tokens(pred : torch.Tensor, parsers_tokens):
+     mask = torch.ones_like(pred, dtype=torch.bool)
+     for tokens in parsers_tokens:
+         mask[:, tokens] = False
+     pred[mask] = -float('inf')
+     return pred[~pred.isinf().all(dim=-1)]
+
+
+ def batched_inference_logits(model : GPT2LMHeadModel, input_ids : torch.Tensor, attention_mask : torch.Tensor | None = None, batch_size : int = 32) -> torch.Tensor:
+     logits = []
+     if attention_mask is None:
+         attention_mask = torch.ones_like(input_ids)
+     for i in range(0, input_ids.shape[0], batch_size):
+         logits.append(model(input_ids[i:i+batch_size], attention_mask=attention_mask[i:i+batch_size]).logits)
+     return torch.cat(logits, dim=0)
+
+ def select_mask(source : list, mask : list[bool]) -> list:
+     assert len(source) == len(mask)
+     return [x for x, m in zip(source, mask) if m]
+
+
+ def log1mexp(x: torch.Tensor) -> torch.Tensor:
+     """Numerically accurate evaluation of log(1 - exp(x)) for x < 0.
+     See [Maechler2012accurate]_ for details.
+     """
+     mask = -math.log(2) < x  # x < 0
+     return torch.where(
+         mask,
+         (-x.expm1()).log(),
+         (-x.exp()).log1p(),
+     )
+
+
+
+
+ class AcceptEverythingParser(Parser):
+     def __init__(self, vocab_size : int):
+         self.vocab_size = vocab_size
+         self.tokens = tuple(range(vocab_size))
+
+     def step(self, token):
+         pass
+
+     def next(self):
+         return self.tokens
+
+     def copy(self):
+         return self
+
+ @torch.no_grad()
+ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int, max_length : int, parser : Parser, pad_token_id : int, batch_size=32, num_solutions = None) -> tuple[torch.Tensor, torch.Tensor]:
+     assert input_ids.shape[0] == 1, "Batch size must be 1"
+     device = input_ids.device
+     input_ids = input_ids.cpu()
+
+     if num_solutions is None:
+         num_solutions = beam_size
+     vanilla = parser is None
+     if vanilla:
+         parser = AcceptEverythingParser(model.config.vocab_size)
+
+     parsers_unfinished = [parser]
+     scores_finished = torch.tensor([], dtype=torch.float)
+     solutions_finished = torch.tensor([], dtype=torch.long).view(0,0)
+
+     input_ids_unfinished = input_ids
+     scores_unfinished = torch.tensor([0.0], dtype=torch.float)
+     solutions_unfinished = torch.tensor([], dtype=torch.long).view(1,0)
+
+
+     for _ in range(max_length):
+         if len(input_ids_unfinished) == 0:
+             break
+         pred = batched_inference_logits(model, input_ids_unfinished.to(device), batch_size=batch_size)[:, -1].cpu()
+         parsers_tokens, can_end = get_parsers_tokens(parsers_unfinished)
+         logprobs = torch.log_softmax(pred, dim=-1)
+         logprobs_filtered = apply_mask_tokens(logprobs, parsers_tokens)
+         if len(logprobs_filtered):
+             topk = torch.topk(logprobs_filtered, beam_size, dim=-1) # shape (batch_size, beam_size)
+             values = topk.values + scores_unfinished.unsqueeze(-1)
+             topk_global = values.flatten().topk(beam_size)
+             best_tokens_row = topk_global.indices // beam_size
+             best_tokens, best_tokens_logprobs = topk.indices[best_tokens_row, topk_global.indices % beam_size], topk.values[best_tokens_row, topk_global.indices % beam_size]
+             notinf = ~best_tokens_logprobs.isinf()
+             best_tokens, best_tokens_row, best_tokens_logprobs = best_tokens[notinf], best_tokens_row[notinf], best_tokens_logprobs[notinf]
+         else:
+             best_tokens = torch.tensor([], dtype=torch.long)
+             best_tokens_row = torch.tensor([], dtype=torch.long)
+             best_tokens_logprobs = torch.tensor([], dtype=torch.float)
+
+
+         scores_finished_current = scores_unfinished[can_end]
+         solutions_finished_current = solutions_unfinished[can_end]
+         scores_finished_current = scores_finished_current + log1mexp(logprobs[can_end, select_mask(parsers_tokens, can_end)].logsumexp(dim=-1)).squeeze(-1)
+         scores_finished = torch.cat([scores_finished, scores_finished_current])
+         if len(solutions_finished_current):
+             pad = torch.full((len(scores_finished_current), solutions_finished_current.shape[1] - solutions_finished.shape[1]), pad_token_id, dtype=torch.long)
+             solutions_finished = torch.cat([solutions_finished.view(-1, solutions_finished_current.shape[1]+pad.shape[1]), torch.cat([solutions_finished_current, pad], dim=1)], dim=0)
+         if solutions_finished.numel():
+             # Keep num_solutions best solutions in finished
+             order = scores_finished.argsort(descending=True)
+             solutions_finished = solutions_finished[order][:num_solutions]
+             scores_finished = scores_finished[order][:num_solutions]
+
+
+         input_ids_unfinished = torch.cat([input_ids_unfinished[best_tokens_row], best_tokens.unsqueeze(-1)], dim=-1)
+         scores_unfinished = scores_unfinished[best_tokens_row] + best_tokens_logprobs
+         solutions_unfinished = torch.cat([solutions_unfinished[best_tokens_row], best_tokens.unsqueeze(-1)], dim=-1)
+         parsers_unfinished = [parsers_unfinished[row].copy() for row in best_tokens_row]
+         for parser, token in zip(parsers_unfinished, best_tokens.tolist()):
+             parser.step(token)
+
+     # Special case of vanilla beam search where all answers are valid
+     if vanilla:
+         order = scores_unfinished.argsort(descending=True)
+         scores_finished = scores_unfinished[order][:num_solutions]
+         solutions_finished = solutions_unfinished[order][:num_solutions]
+
+     return scores_finished, solutions_finished
+
+
+ def set_slice_row(x : torch.Tensor, slices : torch.IntTensor, value) -> torch.Tensor:
+     indices = [torch.arange(start, end) for start, end in slices]
+     for i in range(slices.size(0)):
+         x[i].index_fill_(0, indices[i], 0)
+
+ @torch.no_grad()
+ def divergent_logprob(input_ids : torch.Tensor, attention_mask : torch.Tensor | None, model : GPT2LMHeadModel, parsers : Parser | list[Parser] | None, batch_size=32, start : int | torch.IntTensor = None) -> torch.FloatTensor:
+     if start is None:
+         start = 0
+     if isinstance(start, int):
+         start = torch.tensor([start]*input_ids.shape[0])
+     assert start.shape[0] == input_ids.shape[0]
+     # -1 because next token offset
+     start = start - 1
+
+     if attention_mask is None:
+         attention_mask = torch.ones_like(input_ids)
+
+     logits = batched_inference_logits(model, input_ids, attention_mask, batch_size).cpu()
+     input_ids = input_ids.cpu()
+     attention_mask = attention_mask.cpu()
+
+     logsoftmax = torch.log_softmax(logits, dim=-1)
+     log_probs = torch.gather(
+         logsoftmax[:, :-1, :], 2, input_ids[:, 1:, None]
+     ).squeeze(-1)
+     mask = attention_mask[:, 1:].cpu().clone()
+
+     input_len = attention_mask.sum(-1)
+     pos = torch.stack([torch.zeros_like(start), start], dim=-1)
+     pos_anti = pos.flip(1)
+     pos_anti[:, -1] = input_len
+     set_slice_row(mask, pos, 0)
+     vanilla_prob = (log_probs * mask).sum(-1)
+     if parsers is None:
+         parsers = AcceptEverythingParser(model.config.vocab_size)
+     if not isinstance(parsers, (tuple, list)):
+         parsers = [parsers.copy() for _ in range(len(input_ids))]
+     next_possible_tokens = []
+     for i, parser in enumerate(parsers):
+         # +1 because no next-token offset
+         start = pos_anti[i,0]+1
+         for input_id, att in zip(input_ids[i, start:].tolist(), attention_mask[i, start:].tolist()):
+             if not att:
+                 break
+             parser.step(input_id)
+         next_tokens = list(parser.next())
+         try:
+             next_tokens.remove(end_symb)
+         except ValueError:
+             pass
+         next_possible_tokens.append(next_tokens)
+     last_token_log_probs = torch.stack([log1mexp(logsoftmax[i, input_len[i]-1, tokens].logsumexp(-1)).squeeze() for i, tokens in enumerate(next_possible_tokens)])
+     prob = vanilla_prob + last_token_log_probs
+     return prob
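
For orientation before the packaging hunks: `divergent_beamsearch` scores each finished hypothesis as log P(answer tokens) plus log(1 - P(continuing inside the parser)) — the `log1mexp(...logsumexp(...))` term above — i.e. the probability of generating the answer and then diverging, which matches the package summary below. A minimal usage sketch against the 0.1.3 signature; the GPT-2 checkpoint, prompt, and candidate answers are illustrative, and the `MultiChoicesParser` construction follows the upstream project's documented pattern rather than anything shown in this diff:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from multi_choices_parser import MultiChoicesParser
from divergent_beamsearch import divergent_beamsearch

model = GPT2LMHeadModel.from_pretrained("gpt2").eval()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Prompt and candidate answers are illustrative.
input_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids
tokenized_answers = tokenizer([" Paris", " Paris Hilton"]).input_ids
parser = MultiChoicesParser([tokenized_answers])

scores, solutions = divergent_beamsearch(
    input_ids=input_ids,
    model=model,
    beam_size=5,
    max_length=10,
    parser=parser,                       # renamed from multi_choices_parser in 0.1.3
    pad_token_id=tokenizer.eos_token_id,
    num_solutions=2,
)
for score, solution in zip(scores, solutions):
    print(f"{tokenizer.decode(solution)!r}: {score.item():.3f}")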
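
The second new entry point, `divergent_logprob`, computes the same divergence-aware score for a sequence that already contains the answer instead of searching for one: `start` marks where the answer begins, and positions before it are treated as given context. A sketch reusing the objects above; since the unchanged `__init__.py` appears to re-export only `divergent_beamsearch`, the function is imported from the algorithm module here (an assumption inferred from the RECORD entries):

from divergent_beamsearch.algorithm import divergent_logprob

# Prompt + answer as one pre-built sequence (illustrative strings).
full_ids = tokenizer("The capital of France is Paris", return_tensors="pt").input_ids
answer_start = input_ids.shape[1]  # the answer starts right after the prompt

logprob = divergent_logprob(
    full_ids,
    attention_mask=None,   # defaults to all-ones inside the function
    model=model,
    parsers=parser,        # copied internally, one parser per row
    start=answer_start,
)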
@@ -1,11 +1,11 @@
  Metadata-Version: 2.4
  Name: divergent-beamsearch
- Version: 0.1.1
+ Version: 0.1.3
  Summary: A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject.
  License-File: LICENCE
  Requires-Python: >=3.11
  Requires-Dist: multi-choices-parser>=0.9.57
- Requires-Dist: torch>=2.5.1
+ Requires-Dist: torch>=2.0.0
  Requires-Dist: transformers>=4.47.1
  Description-Content-Type: text/markdown

@@ -0,0 +1,6 @@
+ divergent_beamsearch/__init__.py,sha256=Z2R1pkj4EEHMKWVZX0upeE_Jtfb6joxgYHuRNxWc8Zo,43
+ divergent_beamsearch/algorithm.py,sha256=d0xU4OWiCEa5icdXZHoV1P-eKYftYMHhfBZMEVNkRXQ,8649
+ divergent_beamsearch-0.1.3.dist-info/METADATA,sha256=waQn6dvg12V9753CcIQlOR_jcOvfbwAJa24FvR5awy0,2826
+ divergent_beamsearch-0.1.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ divergent_beamsearch-0.1.3.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
+ divergent_beamsearch-0.1.3.dist-info/RECORD,,
@@ -1,21 +1,21 @@
- MIT License
-
- Copyright (c) 2025 Hichem Ammar Khodja
-
- Permission is hereby granted, free of charge, to any person obtaining a copy
- of this software and associated documentation files (the "Software"), to deal
- in the Software without restriction, including without limitation the rights
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- copies of the Software, and to permit persons to whom the Software is
- furnished to do so, subject to the following conditions:
-
- The above copyright notice and this permission notice shall be included in all
- copies or substantial portions of the Software.
-
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ MIT License
+
+ Copyright (c) 2025 Hichem Ammar Khodja
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  SOFTWARE.
@@ -1,6 +0,0 @@
- divergent_beamsearch/__init__.py,sha256=Z2R1pkj4EEHMKWVZX0upeE_Jtfb6joxgYHuRNxWc8Zo,43
- divergent_beamsearch/algorithm.py,sha256=0NvVocEHVlIAXnfKhiUW6PEbG_L7uBgE7NGJtaoJ-Rw,6136
- divergent_beamsearch-0.1.1.dist-info/METADATA,sha256=dFlRtT8pvNDcUDZaac59zsLAWHB5M5maMkPO-DKFDGI,2826
- divergent_beamsearch-0.1.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- divergent_beamsearch-0.1.1.dist-info/licenses/LICENCE,sha256=jDQOOFKJxgrQwcEyipwKcKzj5IX_paD_41c3iOjH3qw,1095
- divergent_beamsearch-0.1.1.dist-info/RECORD,,