divergent-beamsearch 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- divergent_beamsearch/algorithm.py +26 -6
- {divergent_beamsearch-0.1.0.dist-info → divergent_beamsearch-0.1.1.dist-info}/METADATA +1 -1
- divergent_beamsearch-0.1.1.dist-info/RECORD +6 -0
- divergent_beamsearch-0.1.0.dist-info/RECORD +0 -6
- {divergent_beamsearch-0.1.0.dist-info → divergent_beamsearch-0.1.1.dist-info}/WHEEL +0 -0
- {divergent_beamsearch-0.1.0.dist-info → divergent_beamsearch-0.1.1.dist-info}/licenses/LICENCE +0 -0
@@ -47,12 +47,29 @@ def log1mexp(x: torch.Tensor) -> torch.Tensor:
|
|
47
47
|
(-x.exp()).log1p(),
|
48
48
|
)
|
49
49
|
|
50
|
+
class AcceptEverythingParser:
|
51
|
+
def __init__(self, vocab_size : int):
|
52
|
+
self.vocab_size = vocab_size
|
53
|
+
self.tokens = tuple(range(vocab_size))
|
54
|
+
|
55
|
+
def step(self, token):
|
56
|
+
pass
|
57
|
+
|
58
|
+
def next(self):
|
59
|
+
return self.tokens
|
60
|
+
|
61
|
+
def copy(self):
|
62
|
+
return self
|
63
|
+
|
50
64
|
@torch.no_grad()
|
51
65
|
def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int, max_length : int, multi_choices_parser : MultiChoicesParser, pad_token_id : int, batch_size=32, num_solutions = None) -> tuple[torch.Tensor, torch.Tensor]:
|
52
66
|
assert input_ids.shape[0] == 1, "Batch size must be 1"
|
53
67
|
|
54
68
|
if num_solutions is None:
|
55
69
|
num_solutions = beam_size
|
70
|
+
vanilla = multi_choices_parser is None
|
71
|
+
if vanilla:
|
72
|
+
multi_choices_parser = AcceptEverythingParser(model.config.vocab_size)
|
56
73
|
|
57
74
|
parsers_unfinished = [multi_choices_parser]
|
58
75
|
scores_finished = torch.tensor([], dtype=torch.float)
|
@@ -73,9 +90,10 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
|
|
73
90
|
logprobs_filtered = apply_mask_tokens(logprobs, parsers_tokens)
|
74
91
|
if len(logprobs_filtered):
|
75
92
|
topk = torch.topk(logprobs_filtered, beam_size, dim=-1) # shape (batch_size, beam_size)
|
76
|
-
|
93
|
+
values = topk.values + scores_unfinished.unsqueeze(-1)
|
94
|
+
topk_global = values.flatten().topk(beam_size)
|
77
95
|
best_tokens_row = topk_global.indices // beam_size
|
78
|
-
best_tokens, best_tokens_logprobs = topk.indices[best_tokens_row, topk_global.indices % beam_size], topk_global.values
|
96
|
+
best_tokens, best_tokens_logprobs = topk.indices[best_tokens_row, topk_global.indices % beam_size], topk.values[best_tokens_row, topk_global.indices % beam_size]
|
79
97
|
notinf = ~best_tokens_logprobs.isinf()
|
80
98
|
best_tokens, best_tokens_row, best_tokens_logprobs = best_tokens[notinf], best_tokens_row[notinf], best_tokens_logprobs[notinf]
|
81
99
|
else:
|
@@ -104,9 +122,11 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
|
|
104
122
|
parsers_unfinished = [parsers_unfinished[row].copy() for row in best_tokens_row]
|
105
123
|
for parser, token in zip(parsers_unfinished, best_tokens.tolist()):
|
106
124
|
parser.step(token)
|
125
|
+
|
126
|
+
# Special case of vanilla beam search where all answers are valid
|
127
|
+
if vanilla:
|
128
|
+
order = scores_unfinished.argsort(descending=True)
|
129
|
+
scores_finished = scores_unfinished[order][:num_solutions]
|
130
|
+
solutions_finished = solutions_unfinished[order][:num_solutions]
|
107
131
|
|
108
132
|
return scores_finished, solutions_finished
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: divergent-beamsearch
|
3
|
-
Version: 0.1.0
|
3
|
+
Version: 0.1.1
|
4
4
|
Summary: A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject.
|
5
5
|
License-File: LICENCE
|
6
6
|
Requires-Python: >=3.11
|
@@ -0,0 +1,6 @@
|
|
1
|
+
divergent_beamsearch/__init__.py,sha256=Z2R1pkj4EEHMKWVZX0upeE_Jtfb6joxgYHuRNxWc8Zo,43
|
2
|
+
divergent_beamsearch/algorithm.py,sha256=0NvVocEHVlIAXnfKhiUW6PEbG_L7uBgE7NGJtaoJ-Rw,6136
|
3
|
+
divergent_beamsearch-0.1.1.dist-info/METADATA,sha256=dFlRtT8pvNDcUDZaac59zsLAWHB5M5maMkPO-DKFDGI,2826
|
4
|
+
divergent_beamsearch-0.1.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
5
|
+
divergent_beamsearch-0.1.1.dist-info/licenses/LICENCE,sha256=jDQOOFKJxgrQwcEyipwKcKzj5IX_paD_41c3iOjH3qw,1095
|
6
|
+
divergent_beamsearch-0.1.1.dist-info/RECORD,,
|
@@ -1,6 +0,0 @@
|
|
1
|
-
divergent_beamsearch/__init__.py,sha256=Z2R1pkj4EEHMKWVZX0upeE_Jtfb6joxgYHuRNxWc8Zo,43
|
2
|
-
divergent_beamsearch/algorithm.py,sha256=6cWp6XHepSn1rjqQFkASxd8k3OEUarKNAiqPKfrA78k,5324
|
3
|
-
divergent_beamsearch-0.1.0.dist-info/METADATA,sha256=UCkp3rgFZ89kmwFVy_N_dEy45NTGN_yFhf-J6WCCR4U,2826
|
4
|
-
divergent_beamsearch-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
5
|
-
divergent_beamsearch-0.1.0.dist-info/licenses/LICENCE,sha256=jDQOOFKJxgrQwcEyipwKcKzj5IX_paD_41c3iOjH3qw,1095
|
6
|
-
divergent_beamsearch-0.1.0.dist-info/RECORD,,
|
File without changes
|
{divergent_beamsearch-0.1.0.dist-info → divergent_beamsearch-0.1.1.dist-info}/licenses/LICENCE
RENAMED
File without changes
|