divergent-beamsearch 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- divergent_beamsearch/__init__.py +1 -1
- divergent_beamsearch/algorithm.py +62 -3
- {divergent_beamsearch-0.1.2.dist-info → divergent_beamsearch-0.1.4.dist-info}/METADATA +1 -1
- divergent_beamsearch-0.1.4.dist-info/RECORD +6 -0
- divergent_beamsearch-0.1.2.dist-info/RECORD +0 -6
- {divergent_beamsearch-0.1.2.dist-info → divergent_beamsearch-0.1.4.dist-info}/WHEEL +0 -0
- {divergent_beamsearch-0.1.2.dist-info → divergent_beamsearch-0.1.4.dist-info}/licenses/LICENCE +0 -0
divergent_beamsearch/__init__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
from .algorithm import divergent_beamsearch
|
1
|
+
from .algorithm import divergent_beamsearch, divergent_logprob
|
@@ -35,10 +35,12 @@ def apply_mask_tokens(pred : torch.Tensor, parsers_tokens):
|
|
35
35
|
return pred[~pred.isinf().all(dim=-1)]
|
36
36
|
|
37
37
|
|
38
|
-
def batched_inference_logits(model : GPT2LMHeadModel, input_ids : torch.Tensor, attention_mask : torch.Tensor | None = None, batch_size : int = 32) -> torch.Tensor:
    """Run ``model`` over ``input_ids`` in row chunks and concatenate the logits.

    Args:
        model: model called as ``model(ids, attention_mask=mask)``; its output
            must expose ``.logits``.
        input_ids: token ids; chunked along dimension 0.
        attention_mask: mask sliced alongside ``input_ids``; ``None`` means
            every position is attended (all ones, same shape as ``input_ids``).
        batch_size: maximum number of rows per forward pass.

    Returns:
        The per-chunk ``.logits`` concatenated along dimension 0.
        An empty batch produces no chunks, so ``torch.cat`` raises.
    """
    mask = torch.ones_like(input_ids) if attention_mask is None else attention_mask
    chunk_logits = [
        model(input_ids[lo:lo + batch_size], attention_mask=mask[lo:lo + batch_size]).logits
        for lo in range(0, input_ids.shape[0], batch_size)
    ]
    return torch.cat(chunk_logits, dim=0)
|
43
45
|
|
44
46
|
def select_mask(source : list, mask : list[bool]) -> list:
|
@@ -98,7 +100,7 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
|
|
98
100
|
for _ in range(max_length):
|
99
101
|
if len(input_ids_unfinished) == 0:
|
100
102
|
break
|
101
|
-
pred = batched_inference_logits(model, input_ids_unfinished.to(device), batch_size)[:, -1].cpu()
|
103
|
+
pred = batched_inference_logits(model, input_ids_unfinished.to(device), batch_size=batch_size)[:, -1].cpu()
|
102
104
|
parsers_tokens, can_end = get_parsers_tokens(parsers_unfinished)
|
103
105
|
logprobs = torch.log_softmax(pred, dim=-1)
|
104
106
|
logprobs_filtered = apply_mask_tokens(logprobs, parsers_tokens)
|
@@ -144,3 +146,60 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
|
|
144
146
|
solutions_finished = solutions_unfinished[order][:num_solutions]
|
145
147
|
|
146
148
|
return scores_finished, solutions_finished
|
149
|
+
|
150
|
+
|
151
|
+
def set_slice_row(x : torch.Tensor, slices : torch.IntTensor, value) -> torch.Tensor:
    """Fill a half-open range of each row of ``x`` with ``value``, in place.

    For each row ``i`` of ``slices``, the positions ``slices[i, 0] <= j < slices[i, 1]``
    along dimension 0 of ``x[i]`` are set to ``value``. Only the first
    ``slices.size(0)`` rows of ``x`` are touched. An empty or reversed range
    (``end <= start``) leaves that row unchanged.

    Args:
        x: tensor modified in place.
        slices: ``(rows, 2)`` tensor of ``[start, end)`` bounds, one pair per row.
        value: scalar fill value.

    Returns:
        ``x`` itself (the same object), for convenience.
    """
    for row, (begin, end) in enumerate(slices.tolist()):
        # torch.arange raises when end < begin, so treat reversed ranges as empty.
        cols = torch.arange(begin, max(begin, end), device=x.device)
        x[row].index_fill_(0, cols, value)
    return x
|
155
|
+
|
156
|
+
@torch.no_grad()
def divergent_logprob(input_ids : torch.Tensor, attention_mask : torch.Tensor | None, model : GPT2LMHeadModel, parsers : Parser | list[Parser] | None, batch_size=32, start : int | torch.IntTensor = None) -> torch.FloatTensor:
    """Score each sequence by the log-probability of generating it and then diverging.

    For row ``i`` the score is the sum of two terms:

    * the model's log-probabilities of the attended tokens from position
      ``start[i]`` onward (earlier positions act as an unscored prompt);
    * ``log1mexp`` of the log-mass the model puts, at the last attended
      position, on the tokens the row's parser would accept next
      (``end_symb`` excluded).

    ``log1mexp(x)`` is presumably ``log(1 - exp(x))`` -- TODO confirm. If so,
    the second term is the log-probability of *not* continuing the sequence
    the parser accepts.

    Args:
        input_ids: ``(batch, seq)`` token ids.
        attention_mask: ``(batch, seq)`` mask, or ``None`` for all ones. The last
            attended position is taken as ``attention_mask[i].sum() - 1``, which
            assumes right padding -- TODO confirm for left-padded input.
        model: causal LM. Its output exposes ``.logits``, and ``model.config.vocab_size``
            is read when ``parsers`` is ``None``.
        parsers: one of:
            * a single parser, copied once per row;
            * a list/tuple with one parser per row, copied so the caller's
              parsers are not stepped;
            * ``None``, for an ``AcceptEverythingParser`` over the vocabulary.
        batch_size: rows per forward pass.
        start: first scored position, either an int or a per-row tensor;
            ``None`` means 0.

    Returns:
        ``(batch,)`` tensor of scores, on the CPU.

    Raises:
        ValueError: if ``start`` or ``parsers`` does not have one entry per row.
    """
    batch = input_ids.shape[0]
    if start is None:
        start = 0
    if isinstance(start, int):
        start = torch.tensor([start] * batch)
    if start.shape[0] != batch:
        raise ValueError(f"start must have one entry per row ({batch}), got {start.shape[0]}")

    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)

    logits = batched_inference_logits(model, input_ids, attention_mask, batch_size).cpu()
    input_ids = input_ids.cpu()
    attention_mask = attention_mask.cpu()
    start = start.cpu()

    # logsoftmax[:, t] predicts token t + 1, so log_probs[:, t] scores input_ids[:, t + 1].
    logsoftmax = torch.log_softmax(logits, dim=-1)
    log_probs = torch.gather(
        logsoftmax[:, :-1, :], 2, input_ids[:, 1:, None]
    ).squeeze(-1)
    mask = attention_mask[:, 1:].clone()

    # Un-score the prompt. Token t sits at column t - 1 of log_probs/mask, so
    # columns [0, start - 1) are zeroed. Clamp at 0 because for start == 0 the
    # range would be [0, -1), which torch.arange rejects.
    prompt_end = (start - 1).clamp(min=0)
    set_slice_row(mask, torch.stack([torch.zeros_like(prompt_end), prompt_end], dim=-1), 0)
    vanilla_prob = (log_probs * mask).sum(-1)

    if parsers is None:
        parsers = AcceptEverythingParser(model.config.vocab_size)
    if isinstance(parsers, (tuple, list)):
        # Copy so that stepping below does not mutate the caller's parsers.
        parsers = [parser.copy() for parser in parsers]
    else:
        parsers = [parsers.copy() for _ in range(batch)]
    if len(parsers) != batch:
        # A mismatched list would otherwise broadcast silently in the final sum.
        raise ValueError(f"expected one parser per row ({batch}), got {len(parsers)}")

    input_len = attention_mask.sum(-1)
    next_possible_tokens = []
    for i, parser in enumerate(parsers):
        # Feed the parser the scored tokens (no next-token offset here),
        # stopping at the first unattended position.
        row_start = int(start[i])
        for input_id, att in zip(input_ids[i, row_start:].tolist(), attention_mask[i, row_start:].tolist()):
            if not att:
                break
            parser.step(input_id)
        next_tokens = list(parser.next())
        # end_symb is dropped from the continuation set -- presumably the
        # parser's end marker rather than a vocabulary id; TODO confirm.
        try:
            next_tokens.remove(end_symb)
        except ValueError:
            pass
        next_possible_tokens.append(next_tokens)

    last_token_log_probs = torch.stack([
        log1mexp(logsoftmax[i, input_len[i] - 1, tokens].logsumexp(-1)).squeeze()
        for i, tokens in enumerate(next_possible_tokens)
    ])
    return vanilla_prob + last_token_log_probs
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: divergent-beamsearch
|
3
|
-
Version: 0.1.2
|
3
|
+
Version: 0.1.4
|
4
4
|
Summary: A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject.
|
5
5
|
License-File: LICENCE
|
6
6
|
Requires-Python: >=3.11
|
@@ -0,0 +1,6 @@
|
|
1
|
+
divergent_beamsearch/__init__.py,sha256=qrpVRoT3d-q1N9fJnzHI2X13e71LDY4-6eLOQ_gwCqQ,62
|
2
|
+
divergent_beamsearch/algorithm.py,sha256=d0xU4OWiCEa5icdXZHoV1P-eKYftYMHhfBZMEVNkRXQ,8649
|
3
|
+
divergent_beamsearch-0.1.4.dist-info/METADATA,sha256=f1nBA8_Q3a-PgtE6Z5YViFhWVRvuE3aLkPFQyin9274,2826
|
4
|
+
divergent_beamsearch-0.1.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
5
|
+
divergent_beamsearch-0.1.4.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
|
6
|
+
divergent_beamsearch-0.1.4.dist-info/RECORD,,
|
@@ -1,6 +0,0 @@
|
|
1
|
-
divergent_beamsearch/__init__.py,sha256=Z2R1pkj4EEHMKWVZX0upeE_Jtfb6joxgYHuRNxWc8Zo,43
|
2
|
-
divergent_beamsearch/algorithm.py,sha256=w6aLDOnLwLabmHHOMCEx1Y8P8yaHkFEJxMGNw6f7RsU,6115
|
3
|
-
divergent_beamsearch-0.1.2.dist-info/METADATA,sha256=fjqa8W8RpRSZOFV-3_D28o2353p54mB7bhEUrsRSoxw,2826
|
4
|
-
divergent_beamsearch-0.1.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
5
|
-
divergent_beamsearch-0.1.2.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
|
6
|
-
divergent_beamsearch-0.1.2.dist-info/RECORD,,
|
File without changes
|
{divergent_beamsearch-0.1.2.dist-info → divergent_beamsearch-0.1.4.dist-info}/licenses/LICENCE
RENAMED
File without changes
|