divergent-beamsearch 0.1.2-py3-none-any.whl → 0.1.4-py3-none-any.whl
- divergent_beamsearch/__init__.py +1 -1
- divergent_beamsearch/algorithm.py +62 -3
- {divergent_beamsearch-0.1.2.dist-info → divergent_beamsearch-0.1.4.dist-info}/METADATA +1 -1
- divergent_beamsearch-0.1.4.dist-info/RECORD +6 -0
- divergent_beamsearch-0.1.2.dist-info/RECORD +0 -6
- {divergent_beamsearch-0.1.2.dist-info → divergent_beamsearch-0.1.4.dist-info}/WHEEL +0 -0
- {divergent_beamsearch-0.1.2.dist-info → divergent_beamsearch-0.1.4.dist-info}/licenses/LICENCE +0 -0
divergent_beamsearch/__init__.py
CHANGED
```diff
@@ -1 +1 @@
-from .algorithm import divergent_beamsearch
+from .algorithm import divergent_beamsearch, divergent_logprob
```
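With `divergent_logprob` re-exported, both entry points are importable from the package root. A minimal sketch, assuming the 0.1.4 wheel is installed:

```python
# Top-level imports available as of 0.1.4; previously only
# divergent_beamsearch was re-exported by __init__.py.
from divergent_beamsearch import divergent_beamsearch, divergent_logprob
```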
divergent_beamsearch/algorithm.py
CHANGED
```diff
@@ -35,10 +35,12 @@ def apply_mask_tokens(pred : torch.Tensor, parsers_tokens):
     return pred[~pred.isinf().all(dim=-1)]
 
 
-def batched_inference_logits(model : GPT2LMHeadModel, input_ids : torch.Tensor, batch_size : int = 32) -> torch.Tensor:
+def batched_inference_logits(model : GPT2LMHeadModel, input_ids : torch.Tensor, attention_mask : torch.Tensor | None = None, batch_size : int = 32) -> torch.Tensor:
     logits = []
+    if attention_mask is None:
+        attention_mask = torch.ones_like(input_ids)
     for i in range(0, input_ids.shape[0], batch_size):
-        logits.append(model(input_ids[i:i+batch_size]).logits)
+        logits.append(model(input_ids[i:i+batch_size], attention_mask=attention_mask[i:i+batch_size]).logits)
     return torch.cat(logits, dim=0)
 
 def select_mask(source : list, mask : list[bool]) -> list:
```
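The new `attention_mask` parameter lets the batched helper score padded batches correctly; when omitted, it falls back to an all-ones mask, which preserves the 0.1.2 behaviour for unpadded inputs. A minimal sketch of how the extended signature might be used; the tokenizer setup and prompts are illustrative assumptions, not part of the diff:

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

from divergent_beamsearch.algorithm import batched_inference_logits

model = GPT2LMHeadModel.from_pretrained("gpt2").eval()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

# Prompts of different lengths force padding, which is where the mask matters.
batch = tokenizer(["Paris is the capital of", "Hello"],
                  return_tensors="pt", padding=True)

with torch.no_grad():
    logits = batched_inference_logits(
        model,
        batch["input_ids"],
        attention_mask=batch["attention_mask"],  # new in 0.1.4
        batch_size=32,
    )
print(logits.shape)  # (2, padded_seq_len, vocab_size)
```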
```diff
@@ -98,7 +100,7 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
     for _ in range(max_length):
         if len(input_ids_unfinished) == 0:
             break
-        pred = batched_inference_logits(model, input_ids_unfinished.to(device), batch_size)[:, -1].cpu()
+        pred = batched_inference_logits(model, input_ids_unfinished.to(device), batch_size=batch_size)[:, -1].cpu()
         parsers_tokens, can_end = get_parsers_tokens(parsers_unfinished)
         logprobs = torch.log_softmax(pred, dim=-1)
         logprobs_filtered = apply_mask_tokens(logprobs, parsers_tokens)
```
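The call-site fix above is required by the signature change: `attention_mask` now sits between `input_ids` and `batch_size` in `batched_inference_logits`, so passing `batch_size` positionally would silently bind it as an attention mask; spelling out `batch_size=batch_size` keeps the old behaviour.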
```diff
@@ -144,3 +146,60 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
     solutions_finished = solutions_unfinished[order][:num_solutions]
 
     return scores_finished, solutions_finished
+
+
+def set_slice_row(x : torch.Tensor, slices : torch.IntTensor, value) -> torch.Tensor:
+    indices = [torch.arange(start, end) for start, end in slices]
+    for i in range(slices.size(0)):
+        x[i].index_fill_(0, indices[i], 0)
+
+@torch.no_grad()
+def divergent_logprob(input_ids : torch.Tensor, attention_mask : torch.Tensor | None, model : GPT2LMHeadModel, parsers : Parser | list[Parser] | None, batch_size=32, start : int | torch.IntTensor = None) -> torch.FloatTensor:
+    if start is None:
+        start = 0
+    if isinstance(start, int):
+        start = torch.tensor([start]*input_ids.shape[0])
+    assert start.shape[0] == input_ids.shape[0]
+    # -1 because next token offset
+    start = start - 1
+
+    if attention_mask is None:
+        attention_mask = torch.ones_like(input_ids)
+
+    logits = batched_inference_logits(model, input_ids, attention_mask, batch_size).cpu()
+    input_ids = input_ids.cpu()
+    attention_mask = attention_mask.cpu()
+
+    logsoftmax = torch.log_softmax(logits, dim=-1)
+    log_probs = torch.gather(
+        logsoftmax[:, :-1, :], 2, input_ids[:, 1:, None]
+    ).squeeze(-1)
+    mask = attention_mask[:, 1:].cpu().clone()
+
+    input_len = attention_mask.sum(-1)
+    pos = torch.stack([torch.zeros_like(start), start], dim=-1)
+    pos_anti = pos.flip(1)
+    pos_anti[:, -1] = input_len
+    set_slice_row(mask, pos, 0)
+    vanilla_prob = (log_probs * mask).sum(-1)
+    if parsers is None:
+        parsers = AcceptEverythingParser(model.config.vocab_size)
+    if not isinstance(parsers, (tuple, list)):
+        parsers = [parsers.copy() for _ in range(len(input_ids))]
+    next_possible_tokens = []
+    for i, parser in enumerate(parsers):
+        # +1 because no next-token offset
+        start = pos_anti[i,0]+1
+        for input_id, att in zip(input_ids[i, start:].tolist(), attention_mask[i, start:].tolist()):
+            if not att:
+                break
+            parser.step(input_id)
+        next_tokens = list(parser.next())
+        try:
+            next_tokens.remove(end_symb)
+        except ValueError:
+            pass
+        next_possible_tokens.append(next_tokens)
+    last_token_log_probs = torch.stack([log1mexp(logsoftmax[i, input_len[i]-1, tokens].logsumexp(-1)).squeeze() for i, tokens in enumerate(next_possible_tokens)])
+    prob = vanilla_prob + last_token_log_probs
+    return prob
```
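Taken together, `divergent_logprob` scores each row of `input_ids` in two parts. `vanilla_prob` sums the per-token log-probabilities from `start` onward (`set_slice_row` zeroes the mask over the prompt positions, and the attention mask already zeroes padding). `last_token_log_probs` then adds `log1mexp(logsumexp(...))`, i.e. log(1 - P(next token stays inside the parser's language)), so the result is the log-probability of generating the answer and then diverging. Below is a usage sketch; the `MultiChoicesParser` construction follows the package's `multi_choices_parser` dependency, and that API, the prompts, and the parser's `.copy()` behaviour are assumptions here, not facts established by this diff:

```python
import torch
from multi_choices_parser import MultiChoicesParser  # dependency; API assumed
from transformers import GPT2LMHeadModel, GPT2Tokenizer

from divergent_beamsearch import divergent_logprob

model = GPT2LMHeadModel.from_pretrained("gpt2").eval()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# The parser encodes which token sequences count as "staying on the answer".
possible_answers = [' Paris', ' Paris Hilton']
parser = MultiChoicesParser([tokenizer(possible_answers).input_ids])

prompt = tokenizer("The capital of France is", return_tensors="pt").input_ids
answer = tokenizer(" Paris", return_tensors="pt").input_ids
input_ids = torch.cat([prompt, answer], dim=1)

# start marks where the scored span (the answer) begins; everything before
# it is conditioning context and is masked out of vanilla_prob.
logprob = divergent_logprob(
    input_ids,
    attention_mask=None,   # defaults to all-ones inside the function
    model=model,
    parsers=parser,        # a single parser is copied once per batch row
    start=prompt.shape[1],
)
```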
{divergent_beamsearch-0.1.2.dist-info → divergent_beamsearch-0.1.4.dist-info}/METADATA
RENAMED
```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: divergent-beamsearch
-Version: 0.1.2
+Version: 0.1.4
 Summary: A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject.
 License-File: LICENCE
 Requires-Python: >=3.11
```
divergent_beamsearch-0.1.4.dist-info/RECORD
ADDED
```diff
@@ -0,0 +1,6 @@
+divergent_beamsearch/__init__.py,sha256=qrpVRoT3d-q1N9fJnzHI2X13e71LDY4-6eLOQ_gwCqQ,62
+divergent_beamsearch/algorithm.py,sha256=d0xU4OWiCEa5icdXZHoV1P-eKYftYMHhfBZMEVNkRXQ,8649
+divergent_beamsearch-0.1.4.dist-info/METADATA,sha256=f1nBA8_Q3a-PgtE6Z5YViFhWVRvuE3aLkPFQyin9274,2826
+divergent_beamsearch-0.1.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+divergent_beamsearch-0.1.4.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
+divergent_beamsearch-0.1.4.dist-info/RECORD,,
```
divergent_beamsearch-0.1.2.dist-info/RECORD
DELETED
```diff
@@ -1,6 +0,0 @@
-divergent_beamsearch/__init__.py,sha256=Z2R1pkj4EEHMKWVZX0upeE_Jtfb6joxgYHuRNxWc8Zo,43
-divergent_beamsearch/algorithm.py,sha256=w6aLDOnLwLabmHHOMCEx1Y8P8yaHkFEJxMGNw6f7RsU,6115
-divergent_beamsearch-0.1.2.dist-info/METADATA,sha256=fjqa8W8RpRSZOFV-3_D28o2353p54mB7bhEUrsRSoxw,2826
-divergent_beamsearch-0.1.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-divergent_beamsearch-0.1.2.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
-divergent_beamsearch-0.1.2.dist-info/RECORD,,
```
{divergent_beamsearch-0.1.2.dist-info → divergent_beamsearch-0.1.4.dist-info}/WHEEL
RENAMED
File without changes
{divergent_beamsearch-0.1.2.dist-info → divergent_beamsearch-0.1.4.dist-info}/licenses/LICENCE
RENAMED
File without changes