divergent-beamsearch 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- divergent_beamsearch/algorithm.py +146 -132
- {divergent_beamsearch-0.1.1.dist-info → divergent_beamsearch-0.1.2.dist-info}/METADATA +2 -2
- divergent_beamsearch-0.1.2.dist-info/RECORD +6 -0
- {divergent_beamsearch-0.1.1.dist-info → divergent_beamsearch-0.1.2.dist-info}/licenses/LICENCE +20 -20
- divergent_beamsearch-0.1.1.dist-info/RECORD +0 -6
- {divergent_beamsearch-0.1.1.dist-info → divergent_beamsearch-0.1.2.dist-info}/WHEEL +0 -0
@@ -1,132 +1,146 @@
|
|
1
|
-
import math
|
2
|
-
import torch
|
3
|
-
from transformers import GPT2LMHeadModel
|
4
|
-
from multi_choices_parser import MultiChoicesParser, end_symb
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
1
|
+
import math
|
2
|
+
import torch
|
3
|
+
from transformers import GPT2LMHeadModel
|
4
|
+
from multi_choices_parser import MultiChoicesParser, end_symb
|
5
|
+
|
6
|
+
|
7
|
+
class Parser:
    """Interface for the constrained-decoding parsers driven by the beam search.

    Subclasses must implement all three methods; this base class only raises
    NotImplementedError for each of them.
    """

    def next(self):
        """Return the token ids allowed next.

        The iterable may also contain ``end_symb`` to signal that the text
        consumed so far is a complete answer.
        """
        raise NotImplementedError

    def step(self, token):
        """Advance the parser state by consuming the token id ``token``."""
        raise NotImplementedError

    def copy(self):
        """Return a parser that can be stepped independently when a beam branches."""
        raise NotImplementedError
|
16
|
+
|
17
|
+
def get_parsers_tokens(parsers : list[Parser]) -> tuple[list[list[int]], list[bool]]:
    """Collect each parser's allowed next tokens and whether it may stop here.

    For every parser, ``parser.next()`` is materialised into a list. If that
    list contains ``end_symb``, the marker is removed (first occurrence, as
    ``list.remove`` does) and the matching ``can_end`` flag is True.

    Args:
        parsers: one parser per beam.

    Returns:
        ``(parsers_tokens, can_end)``, both aligned with ``parsers``:
        ``parsers_tokens[i]`` lists the token ids parser ``i`` accepts next
        (without the end marker), and ``can_end[i]`` tells whether parser
        ``i`` accepted ``end_symb``.
    """
    parsers_tokens = []
    can_end = []
    for parser in parsers:
        tokens = list(parser.next())
        ends_here = end_symb in tokens
        if ends_here:
            # The end marker is not a real vocabulary id, so it must not be
            # used as a column index by the masking step.
            tokens.remove(end_symb)
        can_end.append(ends_here)
        parsers_tokens.append(tokens)
    return parsers_tokens, can_end
|
29
|
+
|
30
|
+
def apply_mask_tokens(pred : torch.Tensor, parsers_tokens) -> torch.Tensor:
    """Set to ``-inf`` every score the corresponding row's parser does not allow.

    Row ``i`` of ``pred`` is paired with ``parsers_tokens[i]``. Only the token
    ids listed there keep their value; every other column becomes ``-inf``.

    Rows whose parser allows no token at all are kept (entirely ``-inf``)
    rather than dropped. The output therefore stays row-aligned with ``pred``
    and with any per-row bookkeeping the caller keeps. ``pred`` itself is not
    modified.

    Args:
        pred: 2-D scores, one row per beam.
        parsers_tokens: one sequence of allowed token ids per row of ``pred``.

    Returns:
        A new tensor with the same shape as ``pred``.

    Raises:
        ValueError: if the number of token sequences differs from the number
            of rows of ``pred``.
    """
    if len(parsers_tokens) != pred.shape[0]:
        raise ValueError(
            f"expected one token list per row: got {len(parsers_tokens)} "
            f"lists for {pred.shape[0]} rows"
        )
    mask = torch.ones_like(pred, dtype=torch.bool)
    for row, tokens in enumerate(parsers_tokens):
        # Unmask only this row's own tokens: unmasking the union over all
        # parsers would let a beam take a step its parser does not accept.
        if len(tokens):
            mask[row, tokens] = False
    return pred.masked_fill(mask, -float('inf'))
|
36
|
+
|
37
|
+
|
38
|
+
def batched_inference_logits(model : GPT2LMHeadModel, input_ids : torch.Tensor, batch_size : int = 32) -> torch.Tensor:
    """Run ``model`` on ``input_ids`` in row chunks and concatenate the logits.

    The rows of ``input_ids`` are fed ``batch_size`` at a time, in order, and
    the ``.logits`` of every forward pass are concatenated along dim 0.
    The result therefore has one entry per input row, in the original order.
    """
    chunk_starts = range(0, input_ids.shape[0], batch_size)
    chunk_logits = [model(input_ids[start:start + batch_size]).logits for start in chunk_starts]
    return torch.cat(chunk_logits, dim=0)
|
43
|
+
|
44
|
+
def select_mask(source : list, mask : list[bool]) -> list:
    """Return, in order, the items of ``source`` whose paired ``mask`` flag is truthy.

    ``source`` and ``mask`` must have the same length; an AssertionError is
    raised otherwise.
    """
    assert len(source) == len(mask)
    kept = []
    for item, keep in zip(source, mask):
        if keep:
            kept.append(item)
    return kept
|
47
|
+
|
48
|
+
|
49
|
+
def log1mexp(x: torch.Tensor) -> torch.Tensor:
    """Compute log(1 - exp(x)) elementwise, accurately, for x <= 0.

    Follows Maechler (2012): for x close to 0 (x > -log 2) use
    ``log(-expm1(x))``; otherwise use ``log1p(-exp(x))``. Each form avoids
    the cancellation the naive formula suffers in its own range.
    """
    near_zero = x > -math.log(2)
    close_form = torch.log(-torch.expm1(x))
    far_form = torch.log1p(-torch.exp(x))
    return torch.where(near_zero, close_form, far_form)
|
59
|
+
|
60
|
+
|
61
|
+
|
62
|
+
|
63
|
+
class AcceptEverythingParser(Parser):
    """Parser that places no constraint on generation.

    Every id in ``range(vocab_size)`` is always allowed, and the end marker is
    never offered. The instance holds no state that ``step`` changes.
    """

    def __init__(self, vocab_size : int):
        self.vocab_size = vocab_size
        # Built once; ``next`` hands back the same tuple on every call.
        self.tokens = tuple(range(self.vocab_size))

    def next(self):
        """Return all token ids ``0 .. vocab_size - 1``."""
        return self.tokens

    def step(self, token):
        """Accept ``token`` without changing anything."""
        return None

    def copy(self):
        """Return ``self``; sharing is safe because ``step`` never mutates."""
        return self
|
76
|
+
|
77
|
+
def _divergence_logprobs(logprobs: torch.Tensor, parsers_tokens: list, can_end: list[bool]) -> torch.Tensor:
    """Return log(1 - P(continue)) for every row flagged in ``can_end``, in row order.

    P(continue) is the probability mass ``logprobs[row]`` puts on the tokens
    that row's own parser still accepts. A parser with no continuation gives
    P(continue) = 0 and hence a term of 0.
    """
    terms = []
    for row, (tokens, may_end) in enumerate(zip(parsers_tokens, can_end)):
        if not may_end:
            continue
        if len(tokens):
            # Pair the row with *its own* token list; a single advanced-index
            # over all rows would broadcast rows against the wrong lists.
            terms.append(logprobs[row, tokens].logsumexp(dim=-1))
        else:
            terms.append(logprobs.new_tensor(-float('inf')))
    if not terms:
        return logprobs.new_zeros(0)
    return log1mexp(torch.stack(terms))


def _append_padded(finished: torch.Tensor, current: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    """Right-pad ``finished`` with ``pad_token_id`` to the width of ``current``, then stack ``current`` below."""
    # Solutions grow by exactly one token per step, so previously stored
    # solutions are never wider than the current ones.
    width = current.shape[1]
    pad = torch.full((finished.shape[0], width - finished.shape[1]), pad_token_id, dtype=torch.long)
    return torch.cat([torch.cat([finished, pad], dim=1), current], dim=0)


@torch.no_grad()
def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int, max_length : int, parser : Parser, pad_token_id : int, batch_size=32, num_solutions = None) -> tuple[torch.Tensor, torch.Tensor]:
    """Beam search that ranks complete answers by the probability of stopping after them.

    At each step, beams whose parser accepts ``end_symb`` become finished
    candidates. Their score is the running log-probability plus
    log(1 - P(continuing with a token the parser still accepts)). Unfinished
    beams are extended with the ``beam_size`` best allowed tokens overall.

    Args:
        input_ids: prompt token ids; ``shape[0]`` must be 1 (asserted).
        model: called as ``model(ids).logits``; ``model.config.vocab_size``
            is read when ``parser`` is None.
        beam_size: number of beams kept per step (also the per-row top-k).
        max_length: maximum number of decoding steps.
        parser: constrains allowed tokens. ``None`` runs a vanilla beam
            search that accepts every token.
        pad_token_id: used to right-pad finished solutions that are shorter
            than the longest one.
        batch_size: number of beams per model forward pass.
        num_solutions: number of solutions returned; defaults to ``beam_size``.

    Returns:
        ``(scores, solutions)`` sorted by descending score. ``solutions``
        contains only generated tokens (the prompt is excluded). In vanilla
        mode these are the final unfinished beams.

    NOTE: a beam is checked for completion before its next token is appended,
    so outside vanilla mode returned answers have at most ``max_length - 1``
    tokens.
    """
    assert input_ids.shape[0] == 1, "Batch size must be 1"
    device = input_ids.device
    input_ids = input_ids.cpu()

    if num_solutions is None:
        num_solutions = beam_size
    vanilla = parser is None
    if vanilla:
        parser = AcceptEverythingParser(model.config.vocab_size)

    parsers_unfinished = [parser]
    scores_finished = torch.tensor([], dtype=torch.float)
    solutions_finished = torch.tensor([], dtype=torch.long).view(0, 0)

    input_ids_unfinished = input_ids
    scores_unfinished = torch.tensor([0.0], dtype=torch.float)
    solutions_unfinished = torch.tensor([], dtype=torch.long).view(1, 0)

    for _ in range(max_length):
        if len(input_ids_unfinished) == 0:
            break
        pred = batched_inference_logits(model, input_ids_unfinished.to(device), batch_size)[:, -1].cpu()
        parsers_tokens, can_end = get_parsers_tokens(parsers_unfinished)
        logprobs = torch.log_softmax(pred, dim=-1)
        logprobs_filtered = apply_mask_tokens(logprobs, parsers_tokens)
        if len(logprobs_filtered):
            topk = torch.topk(logprobs_filtered, beam_size, dim=-1)  # shape (num_beams, beam_size)
            values = topk.values + scores_unfinished.unsqueeze(-1)
            topk_global = values.flatten().topk(beam_size)
            best_tokens_row = topk_global.indices // beam_size
            best_tokens_col = topk_global.indices % beam_size
            best_tokens = topk.indices[best_tokens_row, best_tokens_col]
            best_tokens_logprobs = topk.values[best_tokens_row, best_tokens_col]
            # Candidates with -inf log-probability come from masked columns.
            notinf = ~best_tokens_logprobs.isinf()
            best_tokens = best_tokens[notinf]
            best_tokens_row = best_tokens_row[notinf]
            best_tokens_logprobs = best_tokens_logprobs[notinf]
        else:
            best_tokens = torch.tensor([], dtype=torch.long)
            best_tokens_row = torch.tensor([], dtype=torch.long)
            best_tokens_logprobs = torch.tensor([], dtype=torch.float)

        # Beams whose parser may end here become finished candidates.
        scores_finished_current = scores_unfinished[can_end] + _divergence_logprobs(logprobs, parsers_tokens, can_end)
        solutions_finished_current = solutions_unfinished[can_end]
        scores_finished = torch.cat([scores_finished, scores_finished_current])
        if len(solutions_finished_current):
            solutions_finished = _append_padded(solutions_finished, solutions_finished_current, pad_token_id)
        if len(scores_finished):
            # Keep num_solutions best solutions in finished
            order = scores_finished.argsort(descending=True)
            solutions_finished = solutions_finished[order][:num_solutions]
            scores_finished = scores_finished[order][:num_solutions]

        input_ids_unfinished = torch.cat([input_ids_unfinished[best_tokens_row], best_tokens.unsqueeze(-1)], dim=-1)
        scores_unfinished = scores_unfinished[best_tokens_row] + best_tokens_logprobs
        solutions_unfinished = torch.cat([solutions_unfinished[best_tokens_row], best_tokens.unsqueeze(-1)], dim=-1)
        parsers_unfinished = [parsers_unfinished[row].copy() for row in best_tokens_row.tolist()]
        for beam_parser, token in zip(parsers_unfinished, best_tokens.tolist()):
            beam_parser.step(token)

    # Special case of vanilla beam search where all answers are valid
    if vanilla:
        order = scores_unfinished.argsort(descending=True)
        scores_finished = scores_unfinished[order][:num_solutions]
        solutions_finished = solutions_unfinished[order][:num_solutions]

    return scores_finished, solutions_finished
|
@@ -1,11 +1,11 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: divergent-beamsearch
|
3
|
-
Version: 0.1.1
|
3
|
+
Version: 0.1.2
|
4
4
|
Summary: A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject.
|
5
5
|
License-File: LICENCE
|
6
6
|
Requires-Python: >=3.11
|
7
7
|
Requires-Dist: multi-choices-parser>=0.9.57
|
8
|
-
Requires-Dist: torch>=2.
|
8
|
+
Requires-Dist: torch>=2.0.0
|
9
9
|
Requires-Dist: transformers>=4.47.1
|
10
10
|
Description-Content-Type: text/markdown
|
11
11
|
|
@@ -0,0 +1,6 @@
|
|
1
|
+
divergent_beamsearch/__init__.py,sha256=Z2R1pkj4EEHMKWVZX0upeE_Jtfb6joxgYHuRNxWc8Zo,43
|
2
|
+
divergent_beamsearch/algorithm.py,sha256=w6aLDOnLwLabmHHOMCEx1Y8P8yaHkFEJxMGNw6f7RsU,6115
|
3
|
+
divergent_beamsearch-0.1.2.dist-info/METADATA,sha256=fjqa8W8RpRSZOFV-3_D28o2353p54mB7bhEUrsRSoxw,2826
|
4
|
+
divergent_beamsearch-0.1.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
5
|
+
divergent_beamsearch-0.1.2.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
|
6
|
+
divergent_beamsearch-0.1.2.dist-info/RECORD,,
|
{divergent_beamsearch-0.1.1.dist-info → divergent_beamsearch-0.1.2.dist-info}/licenses/LICENCE
RENAMED
@@ -1,21 +1,21 @@
|
|
1
|
-
MIT License
|
2
|
-
|
3
|
-
Copyright (c) 2025 Hichem Ammar Khodja
|
4
|
-
|
5
|
-
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6
|
-
of this software and associated documentation files (the "Software"), to deal
|
7
|
-
in the Software without restriction, including without limitation the rights
|
8
|
-
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9
|
-
copies of the Software, and to permit persons to whom the Software is
|
10
|
-
furnished to do so, subject to the following conditions:
|
11
|
-
|
12
|
-
The above copyright notice and this permission notice shall be included in all
|
13
|
-
copies or substantial portions of the Software.
|
14
|
-
|
15
|
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16
|
-
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17
|
-
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18
|
-
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19
|
-
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20
|
-
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
1
|
+
MIT License
|
2
|
+
|
3
|
+
Copyright (c) 2025 Hichem Ammar Khodja
|
4
|
+
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
7
|
+
in the Software without restriction, including without limitation the rights
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
10
|
+
furnished to do so, subject to the following conditions:
|
11
|
+
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
13
|
+
copies or substantial portions of the Software.
|
14
|
+
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21
21
|
SOFTWARE.
|
@@ -1,6 +0,0 @@
|
|
1
|
-
divergent_beamsearch/__init__.py,sha256=Z2R1pkj4EEHMKWVZX0upeE_Jtfb6joxgYHuRNxWc8Zo,43
|
2
|
-
divergent_beamsearch/algorithm.py,sha256=0NvVocEHVlIAXnfKhiUW6PEbG_L7uBgE7NGJtaoJ-Rw,6136
|
3
|
-
divergent_beamsearch-0.1.1.dist-info/METADATA,sha256=dFlRtT8pvNDcUDZaac59zsLAWHB5M5maMkPO-DKFDGI,2826
|
4
|
-
divergent_beamsearch-0.1.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
5
|
-
divergent_beamsearch-0.1.1.dist-info/licenses/LICENCE,sha256=jDQOOFKJxgrQwcEyipwKcKzj5IX_paD_41c3iOjH3qw,1095
|
6
|
-
divergent_beamsearch-0.1.1.dist-info/RECORD,,
|
File without changes
|