divergent-beamsearch 0.1.2.tar.gz → 0.1.3.tar.gz
- {divergent_beamsearch-0.1.2 → divergent_beamsearch-0.1.3}/PKG-INFO +1 -1
- {divergent_beamsearch-0.1.2 → divergent_beamsearch-0.1.3}/pyproject.toml +1 -1
- {divergent_beamsearch-0.1.2 → divergent_beamsearch-0.1.3}/src/divergent_beamsearch/algorithm.py +62 -3
- {divergent_beamsearch-0.1.2 → divergent_beamsearch-0.1.3}/tests/test_beamsearch.py +35 -1
- {divergent_beamsearch-0.1.2 → divergent_beamsearch-0.1.3}/.gitignore +0 -0
- {divergent_beamsearch-0.1.2 → divergent_beamsearch-0.1.3}/.python-version +0 -0
- {divergent_beamsearch-0.1.2 → divergent_beamsearch-0.1.3}/LICENCE +0 -0
- {divergent_beamsearch-0.1.2 → divergent_beamsearch-0.1.3}/README.md +0 -0
- {divergent_beamsearch-0.1.2 → divergent_beamsearch-0.1.3}/src/divergent_beamsearch/__init__.py +0 -0
- {divergent_beamsearch-0.1.2 → divergent_beamsearch-0.1.3}/uv.lock +0 -0
{divergent_beamsearch-0.1.2 → divergent_beamsearch-0.1.3}/PKG-INFO RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: divergent-beamsearch
-Version: 0.1.2
+Version: 0.1.3
 Summary: A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject.
 License-File: LICENCE
 Requires-Python: >=3.11
{divergent_beamsearch-0.1.2 → divergent_beamsearch-0.1.3}/pyproject.toml RENAMED
@@ -1,6 +1,6 @@
 [project]
 name = "divergent-beamsearch"
-version = "0.1.2"
+version = "0.1.3"
 description = "A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject."
 readme = "README.md"
 requires-python = ">=3.11"
{divergent_beamsearch-0.1.2 → divergent_beamsearch-0.1.3}/src/divergent_beamsearch/algorithm.py RENAMED
@@ -35,10 +35,12 @@ def apply_mask_tokens(pred : torch.Tensor, parsers_tokens):
     return pred[~pred.isinf().all(dim=-1)]
 
 
-def batched_inference_logits(model : GPT2LMHeadModel, input_ids : torch.Tensor, batch_size : int = 32) -> torch.Tensor:
+def batched_inference_logits(model : GPT2LMHeadModel, input_ids : torch.Tensor, attention_mask : torch.Tensor | None = None, batch_size : int = 32) -> torch.Tensor:
     logits = []
+    if attention_mask is None:
+        attention_mask = torch.ones_like(input_ids)
     for i in range(0, input_ids.shape[0], batch_size):
-        logits.append(model(input_ids[i:i+batch_size]).logits)
+        logits.append(model(input_ids[i:i+batch_size], attention_mask=attention_mask[i:i+batch_size]).logits)
     return torch.cat(logits, dim=0)
 
 def select_mask(source : list, mask : list[bool]) -> list:
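The new `attention_mask` parameter lets callers score padded batches correctly; it defaults to all-ones, preserving the old behaviour. A minimal usage sketch, assuming the stock "gpt2" checkpoint (the padding setup mirrors the new test further down):

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from divergent_beamsearch.algorithm import batched_inference_logits

# Assumption: the stock "gpt2" checkpoint; any GPT2LMHeadModel works.
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

# Padded batch: attention_mask is 1 on real tokens, 0 on padding.
batch = tokenizer(["The capital of France is", "Hello"],
                  return_tensors="pt", padding=True)
with torch.no_grad():
    logits = batched_inference_logits(model, batch.input_ids,
                                      batch.attention_mask, batch_size=32)
print(logits.shape)  # (2, sequence_length, vocab_size)
```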
@@ -98,7 +100,7 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
     for _ in range(max_length):
         if len(input_ids_unfinished) == 0:
             break
-        pred = batched_inference_logits(model, input_ids_unfinished.to(device), batch_size)[:, -1].cpu()
+        pred = batched_inference_logits(model, input_ids_unfinished.to(device), batch_size=batch_size)[:, -1].cpu()
         parsers_tokens, can_end = get_parsers_tokens(parsers_unfinished)
         logprobs = torch.log_softmax(pred, dim=-1)
         logprobs_filtered = apply_mask_tokens(logprobs, parsers_tokens)
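Why this one-line change matters: with `attention_mask` now occupying the third positional slot of `batched_inference_logits`, passing `batch_size` positionally would silently bind it to `attention_mask`. A sketch of the pitfall (hypothetical `ids` tensor):

```python
# Signature after this release:
#   batched_inference_logits(model, input_ids, attention_mask=None, batch_size=32)
batched_inference_logits(model, ids, 64)             # bug: 64 binds to attention_mask
batched_inference_logits(model, ids, batch_size=64)  # correct: mask defaults to all-ones
```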
@@ -144,3 +146,60 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
     solutions_finished = solutions_unfinished[order][:num_solutions]
 
     return scores_finished, solutions_finished
+
+
+def set_slice_row(x : torch.Tensor, slices : torch.IntTensor, value) -> torch.Tensor:
+    indices = [torch.arange(start, end) for start, end in slices]
+    for i in range(slices.size(0)):
+        x[i].index_fill_(0, indices[i], 0)
+
+@torch.no_grad()
+def divergent_logprob(input_ids : torch.Tensor, attention_mask : torch.Tensor | None, model : GPT2LMHeadModel, parsers : Parser | list[Parser] | None, batch_size=32, start : int | torch.IntTensor = None) -> torch.FloatTensor:
+    if start is None:
+        start = 0
+    if isinstance(start, int):
+        start = torch.tensor([start]*input_ids.shape[0])
+    assert start.shape[0] == input_ids.shape[0]
+    # -1 because next token offset
+    start = start - 1
+
+    if attention_mask is None:
+        attention_mask = torch.ones_like(input_ids)
+
+    logits = batched_inference_logits(model, input_ids, attention_mask, batch_size).cpu()
+    input_ids = input_ids.cpu()
+    attention_mask = attention_mask.cpu()
+
+    logsoftmax = torch.log_softmax(logits, dim=-1)
+    log_probs = torch.gather(
+        logsoftmax[:, :-1, :], 2, input_ids[:, 1:, None]
+    ).squeeze(-1)
+    mask = attention_mask[:, 1:].cpu().clone()
+
+    input_len = attention_mask.sum(-1)
+    pos = torch.stack([torch.zeros_like(start), start], dim=-1)
+    pos_anti = pos.flip(1)
+    pos_anti[:, -1] = input_len
+    set_slice_row(mask, pos, 0)
+    vanilla_prob = (log_probs * mask).sum(-1)
+    if parsers is None:
+        parsers = AcceptEverythingParser(model.config.vocab_size)
+    if not isinstance(parsers, (tuple, list)):
+        parsers = [parsers.copy() for _ in range(len(input_ids))]
+    next_possible_tokens = []
+    for i, parser in enumerate(parsers):
+        # +1 because no next-token offset
+        start = pos_anti[i,0]+1
+        for input_id, att in zip(input_ids[i, start:].tolist(), attention_mask[i, start:].tolist()):
+            if not att:
+                break
+            parser.step(input_id)
+        next_tokens = list(parser.next())
+        try:
+            next_tokens.remove(end_symb)
+        except ValueError:
+            pass
+        next_possible_tokens.append(next_tokens)
+    last_token_log_probs = torch.stack([log1mexp(logsoftmax[i, input_len[i]-1, tokens].logsumexp(-1)).squeeze() for i, tokens in enumerate(next_possible_tokens)])
+    prob = vanilla_prob + last_token_log_probs
+    return prob
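A condensed usage sketch of the new `divergent_logprob` function, adapted from the test added below (reusing the `model` and `tokenizer` from the earlier sketch; `start` marks where each answer begins inside its row):

```python
import torch
from multi_choices_parser import MultiChoicesParser
from divergent_beamsearch.algorithm import divergent_logprob

# Each row ends with a candidate answer whose divergent log-probability we want.
prompts = ["The capital of France is Paris", "The top model Paris Hilton"]
inp = tokenizer(prompts, return_tensors="pt", padding=True)

# The parser enumerates every token sequence that counts as a valid answer.
answers = [' Paris', ' Paris Hilton']
parser = MultiChoicesParser([tokenizer(answers).input_ids])

# start[i] = index of the first answer token; the answers above are the last
# 1 and 2 tokens of their prompts, respectively.
input_len = inp.attention_mask.sum(-1)
scores = divergent_logprob(inp.input_ids, inp.attention_mask, model, parser,
                           start=input_len - torch.tensor([1, 2]))
```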
{divergent_beamsearch-0.1.2 → divergent_beamsearch-0.1.3}/tests/test_beamsearch.py RENAMED
@@ -3,7 +3,7 @@ import pytest
 import torch
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
 from multi_choices_parser import MultiChoicesParser
-from divergent_beamsearch.algorithm import divergent_beamsearch, log1mexp
+from divergent_beamsearch.algorithm import divergent_beamsearch, divergent_logprob, log1mexp
 from multi_choices_parser import MultiChoicesParser
 
 @pytest.fixture
@@ -46,6 +46,40 @@ def test_divergent_beamsearch(model_and_tokenizer, device):
     assert scores[0] == logprob_paris + log1mexp(logprob_hilton), "Beam search did not return the expected score"
     assert scores[1] == logprob_paris_hilton, "Beam search did not return the expected score"
 
+@pytest.mark.parametrize("device", ['cpu', 'cuda'])
+def test_divergent_logprob(model_and_tokenizer, device):
+    if device == 'cuda' and not torch.cuda.is_available():
+        pytest.skip("CUDA is not available on this machine.")
+    model, tokenizer = model_and_tokenizer
+    model.to(device)
+    prompts = [
+        "The capital of France is Paris",
+        "The top model Paris Hilton"
+    ]
+    tokenizer.pad_token = tokenizer.eos_token
+    inp = tokenizer(prompts, return_tensors="pt", padding=True)
+    input_ids = inp.input_ids.to(device)
+    attention_mask = inp.attention_mask.to(device)
+
+    possible_answers = [' Paris', ' Paris Hilton']
+    tokenized_answers = tokenizer(possible_answers).input_ids
+    multi_choices_parser = MultiChoicesParser([tokenized_answers])
+
+    input_len = attention_mask.sum(-1).cpu()
+    probs = divergent_logprob(input_ids, attention_mask, model, multi_choices_parser, start=input_len - torch.tensor([1,2]))
+
+    input_ids_1st = tokenizer("The capital of France is Paris Hilton", return_tensors='pt').input_ids.to(device)
+    logprobs_1st = model(input_ids_1st).logits.cpu().log_softmax(dim=-1)
+    logprob_paris = logprobs_1st[0, input_ids_1st.shape[1]-3, tokenized_answers[1][0]] # P(Paris | The capital of France is)
+    logprob_hilton = logprobs_1st[0, input_ids_1st.shape[1]-2, tokenized_answers[1][1]] # P(Hilton | The capital of France is Paris)
+
+    input_ids_2nd = tokenizer("The top model Paris Hilton", return_tensors='pt').input_ids.to(device)
+    logprobs_2nd = model(input_ids_2nd).logits.cpu().log_softmax(dim=-1)
+    logprob_paris_hilton = logprobs_2nd[0, -3, tokenized_answers[1][0]] + logprobs_2nd[0, -2, tokenized_answers[1][1]] # P(Paris Hilton | The top model)
+
+    assert torch.isclose(probs[0], logprob_paris + log1mexp(logprob_hilton)), "P_divergent(Paris | The capital of France is) is incorrect"
+    assert torch.isclose(probs[1], logprob_paris_hilton), "P_divergent(Paris Hilton | The top model) is incorrect"
+
 @pytest.mark.parametrize("device", ['cpu', 'cuda'])
 def test_vanilla_beamsearch(model_and_tokenizer, device):
     if device == 'cuda' and not torch.cuda.is_available():
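Reading the two assertions: for the first prompt the expected score is log P(Paris | The capital of France is) + log(1 - P(Hilton | The capital of France is Paris)), i.e. the probability of producing " Paris" and then not continuing into the longer admissible answer " Paris Hilton" (log1mexp(x) computes log(1 - exp(x)) stably). For the second prompt, " Paris Hilton" is already the longest admissible answer, so no divergence correction applies and the expected score is the plain log-probability of its two answer tokens.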
{divergent_beamsearch-0.1.2 → divergent_beamsearch-0.1.3}/.gitignore RENAMED
File without changes
{divergent_beamsearch-0.1.2 → divergent_beamsearch-0.1.3}/.python-version RENAMED
File without changes
{divergent_beamsearch-0.1.2 → divergent_beamsearch-0.1.3}/LICENCE RENAMED
File without changes
{divergent_beamsearch-0.1.2 → divergent_beamsearch-0.1.3}/README.md RENAMED
File without changes
{divergent_beamsearch-0.1.2 → divergent_beamsearch-0.1.3}/src/divergent_beamsearch/__init__.py RENAMED
File without changes
{divergent_beamsearch-0.1.2 → divergent_beamsearch-0.1.3}/uv.lock RENAMED
File without changes