divergent-beamsearch 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- divergent_beamsearch/algorithm.py +17 -7
- {divergent_beamsearch-0.1.5.dist-info → divergent_beamsearch-0.1.7.dist-info}/METADATA +2 -2
- divergent_beamsearch-0.1.7.dist-info/RECORD +6 -0
- divergent_beamsearch-0.1.5.dist-info/RECORD +0 -6
- {divergent_beamsearch-0.1.5.dist-info → divergent_beamsearch-0.1.7.dist-info}/WHEEL +0 -0
- {divergent_beamsearch-0.1.5.dist-info → divergent_beamsearch-0.1.7.dist-info}/licenses/LICENCE +0 -0
@@ -1,7 +1,7 @@
|
|
1
1
|
import math
|
2
2
|
import torch
|
3
3
|
from transformers import GPT2LMHeadModel
|
4
|
-
from multi_choices_parser import
|
4
|
+
from multi_choices_parser import DEFAULT_END_SYMB
|
5
5
|
|
6
6
|
|
7
7
|
class Parser:
|
@@ -19,10 +19,10 @@ def get_parsers_tokens(parsers : list[Parser], end_symb) -> tuple[list, list[int
|
|
19
19
|
can_end = []
|
20
20
|
for parser in parsers:
|
21
21
|
tokens = list(parser.next())
|
22
|
-
|
23
|
-
can_end.append(True)
|
22
|
+
try:
|
24
23
|
tokens.remove(end_symb)
|
25
|
-
|
24
|
+
can_end.append(True)
|
25
|
+
except ValueError:
|
26
26
|
can_end.append(False)
|
27
27
|
parsers_tokens.append(tokens)
|
28
28
|
return parsers_tokens, can_end
|
@@ -76,8 +76,14 @@ class AcceptEverythingParser(Parser):
|
|
76
76
|
def copy(self):
|
77
77
|
return self
|
78
78
|
|
79
|
+
def index_reduce_lists(x : torch.Tensor, indices : list[list[int]], reduce_func=torch.sum) -> torch.Tensor:
|
80
|
+
values = []
|
81
|
+
for i, index in enumerate(indices):
|
82
|
+
values.append(reduce_func(x[i, index], dim=-1))
|
83
|
+
return torch.tensor(values, dtype=x.dtype, device=x.device, requires_grad=x.requires_grad)
|
84
|
+
|
79
85
|
@torch.no_grad()
|
80
|
-
def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int, max_length : int, parser : Parser, pad_token_id : int, batch_size=32, num_solutions = None, end_symb=
|
86
|
+
def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int, max_length : int, parser : Parser, pad_token_id : int, batch_size=32, num_solutions = None, end_symb=DEFAULT_END_SYMB) -> tuple[torch.Tensor, torch.Tensor]:
|
81
87
|
assert input_ids.shape[0] == 1, "Batch size must be 1"
|
82
88
|
device = input_ids.device
|
83
89
|
input_ids = input_ids.cpu()
|
@@ -120,7 +126,8 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
|
|
120
126
|
|
121
127
|
scores_finished_current = scores_unfinished[can_end]
|
122
128
|
solutions_finished_current = solutions_unfinished[can_end]
|
123
|
-
|
129
|
+
logprob_other_ans = index_reduce_lists(logprobs[can_end], select_mask(parsers_tokens, can_end), reduce_func=torch.logsumexp).squeeze(-1)
|
130
|
+
scores_finished_current = scores_finished_current + log1mexp(logprob_other_ans)
|
124
131
|
scores_finished = torch.cat([scores_finished, scores_finished_current])
|
125
132
|
if len(solutions_finished_current):
|
126
133
|
pad = torch.full((len(scores_finished_current), solutions_finished_current.shape[1] - solutions_finished.shape[1]), pad_token_id, dtype=torch.long)
|
@@ -140,6 +147,7 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
|
|
140
147
|
parser.step(token)
|
141
148
|
|
142
149
|
# Special case of vanilla beam search where all answers are valid
|
150
|
+
# Warning : In this case model will not stop on end_of_sentence token
|
143
151
|
if vanilla:
|
144
152
|
order = scores_unfinished.argsort(descending=True)
|
145
153
|
scores_finished = scores_unfinished[order][:num_solutions]
|
@@ -154,7 +162,9 @@ def set_slice_row(x : torch.Tensor, slices : torch.IntTensor, value) -> torch.Te
|
|
154
162
|
x[i].index_fill_(0, indices[i], 0)
|
155
163
|
|
156
164
|
@torch.no_grad()
|
157
|
-
def divergent_logprob(input_ids : torch.Tensor, attention_mask : torch.Tensor | None, model : GPT2LMHeadModel,
|
165
|
+
def divergent_logprob(input_ids : torch.Tensor, attention_mask : torch.Tensor | None, model : GPT2LMHeadModel,
|
166
|
+
parsers : Parser | list[Parser] | None, batch_size=32,
|
167
|
+
start : int | torch.IntTensor = None, end_symb=DEFAULT_END_SYMB) -> torch.FloatTensor:
|
158
168
|
if start is None:
|
159
169
|
start = 0
|
160
170
|
if isinstance(start, int):
|
@@ -1,10 +1,10 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: divergent-beamsearch
|
3
|
-
Version: 0.1.5
|
3
|
+
Version: 0.1.7
|
4
4
|
Summary: A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject.
|
5
5
|
License-File: LICENCE
|
6
6
|
Requires-Python: >=3.11
|
7
|
-
Requires-Dist: multi-choices-parser>=0.9.
|
7
|
+
Requires-Dist: multi-choices-parser>=0.9.61
|
8
8
|
Requires-Dist: torch>=2.0.0
|
9
9
|
Requires-Dist: transformers>=4.47.1
|
10
10
|
Description-Content-Type: text/markdown
|
@@ -0,0 +1,6 @@
|
|
1
|
+
divergent_beamsearch/__init__.py,sha256=qrpVRoT3d-q1N9fJnzHI2X13e71LDY4-6eLOQ_gwCqQ,62
|
2
|
+
divergent_beamsearch/algorithm.py,sha256=TUm2pbFhR0DqfGDm1fqQXqojNCAkFRmuvj4jbFCxwHc,9228
|
3
|
+
divergent_beamsearch-0.1.7.dist-info/METADATA,sha256=JWuN6f2YjjOXoYxAFzR7vmVYwPyL2HDXI7huY67gAmU,2826
|
4
|
+
divergent_beamsearch-0.1.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
5
|
+
divergent_beamsearch-0.1.7.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
|
6
|
+
divergent_beamsearch-0.1.7.dist-info/RECORD,,
|
@@ -1,6 +0,0 @@
|
|
1
|
-
divergent_beamsearch/__init__.py,sha256=qrpVRoT3d-q1N9fJnzHI2X13e71LDY4-6eLOQ_gwCqQ,62
|
2
|
-
divergent_beamsearch/algorithm.py,sha256=b44kA2_M0HfOSC6LYQmu7W_JnGAl1u9Sz_91jMjvWg0,8688
|
3
|
-
divergent_beamsearch-0.1.5.dist-info/METADATA,sha256=sRitbfgDp8YqLXZzZmZw7Nd1pJWv8T4HJoSqjhO4ztI,2826
|
4
|
-
divergent_beamsearch-0.1.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
5
|
-
divergent_beamsearch-0.1.5.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
|
6
|
-
divergent_beamsearch-0.1.5.dist-info/RECORD,,
|
File without changes
|
{divergent_beamsearch-0.1.5.dist-info → divergent_beamsearch-0.1.7.dist-info}/licenses/LICENCE
RENAMED
File without changes
|