divergent-beamsearch 0.2.0__tar.gz → 0.2.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {divergent_beamsearch-0.2.0 → divergent_beamsearch-0.2.2}/PKG-INFO +2 -3
- {divergent_beamsearch-0.2.0 → divergent_beamsearch-0.2.2}/pyproject.toml +3 -3
- {divergent_beamsearch-0.2.0 → divergent_beamsearch-0.2.2}/src/divergent_beamsearch/algorithm.py +15 -7
- {divergent_beamsearch-0.2.0 → divergent_beamsearch-0.2.2}/tests/test_beamsearch.py +4 -3
- {divergent_beamsearch-0.2.0 → divergent_beamsearch-0.2.2}/uv.lock +19 -10
- {divergent_beamsearch-0.2.0 → divergent_beamsearch-0.2.2}/.gitignore +0 -0
- {divergent_beamsearch-0.2.0 → divergent_beamsearch-0.2.2}/.python-version +0 -0
- {divergent_beamsearch-0.2.0 → divergent_beamsearch-0.2.2}/LICENCE +0 -0
- {divergent_beamsearch-0.2.0 → divergent_beamsearch-0.2.2}/README.md +0 -0
- {divergent_beamsearch-0.2.0 → divergent_beamsearch-0.2.2}/src/divergent_beamsearch/__init__.py +0 -0
@@ -1,12 +1,11 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: divergent-beamsearch
|
3
|
-
Version: 0.2.0
|
3
|
+
Version: 0.2.2
|
4
4
|
Summary: A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject.
|
5
5
|
License-File: LICENCE
|
6
6
|
Requires-Python: >=3.11
|
7
|
-
Requires-Dist: multi-choices-parser>=0.
|
7
|
+
Requires-Dist: multi-choices-parser>=0.10.0
|
8
8
|
Requires-Dist: torch>=2.0.0
|
9
|
-
Requires-Dist: transformers>=4.47.1
|
10
9
|
Description-Content-Type: text/markdown
|
11
10
|
|
12
11
|
# Divergent Beam Search
|
@@ -1,18 +1,18 @@
|
|
1
1
|
[project]
|
2
2
|
name = "divergent-beamsearch"
|
3
|
-
version = "0.2.0"
|
3
|
+
version = "0.2.2"
|
4
4
|
description = "A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject."
|
5
5
|
readme = "README.md"
|
6
6
|
requires-python = ">=3.11"
|
7
7
|
dependencies = [
|
8
|
-
"multi-choices-parser>=0.
|
8
|
+
"multi-choices-parser>=0.10.0",
|
9
9
|
"torch>=2.0.0",
|
10
|
-
"transformers>=4.47.1",
|
11
10
|
]
|
12
11
|
|
13
12
|
[dependency-groups]
|
14
13
|
dev = [
|
15
14
|
"pytest>=8.3.4",
|
15
|
+
"transformers>=4.47.1"
|
16
16
|
]
|
17
17
|
|
18
18
|
[build-system]
|
{divergent_beamsearch-0.2.0 → divergent_beamsearch-0.2.2}/src/divergent_beamsearch/algorithm.py
RENAMED
@@ -1,6 +1,10 @@
|
|
1
1
|
import math
|
2
|
+
import multi_choices_parser
|
2
3
|
import torch
|
3
|
-
|
4
|
+
try:
|
5
|
+
from transformers import GPT2LMHeadModel
|
6
|
+
except ImportError:
|
7
|
+
pass
|
4
8
|
from multi_choices_parser import DEFAULT_END_SYMB
|
5
9
|
|
6
10
|
|
@@ -35,7 +39,7 @@ def apply_mask_tokens(pred : torch.Tensor, parsers_tokens):
|
|
35
39
|
return pred[~pred.isinf().all(dim=-1)]
|
36
40
|
|
37
41
|
|
38
|
-
def batched_inference_logits(model : GPT2LMHeadModel, input_ids : torch.Tensor,
|
42
|
+
def batched_inference_logits(model : "GPT2LMHeadModel", input_ids : torch.Tensor,
|
39
43
|
attention_mask : torch.Tensor | None = None, batch_size : int = 32,
|
40
44
|
to_cpu=False) -> torch.Tensor:
|
41
45
|
logits = []
|
@@ -96,7 +100,7 @@ def pad_to_same_size(tensors : list[torch.Tensor], padding_value : int) -> torch
|
|
96
100
|
return torch.cat(padded_tensors, dim=0)
|
97
101
|
|
98
102
|
@torch.no_grad()
|
99
|
-
def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int,
|
103
|
+
def divergent_beamsearch(input_ids : torch.Tensor, model : "GPT2LMHeadModel", beam_size : int,
|
100
104
|
max_length : int, parser : Parser, pad_token_id : int, batch_size=32,
|
101
105
|
num_solutions = None, end_symb=DEFAULT_END_SYMB, optimize_gpu_mem=True) -> tuple[torch.Tensor, torch.Tensor]:
|
102
106
|
assert input_ids.shape[0] == 1, "Batch size must be 1"
|
@@ -160,9 +164,11 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
|
|
160
164
|
input_ids_unfinished = torch.cat([input_ids_unfinished[best_tokens_row], best_tokens.unsqueeze(-1)], dim=-1)
|
161
165
|
scores_unfinished = scores_unfinished[best_tokens_row] + best_tokens_logprobs
|
162
166
|
solutions_unfinished = torch.cat([solutions_unfinished[best_tokens_row], best_tokens.unsqueeze(-1)], dim=-1)
|
167
|
+
best_tokens_row = best_tokens_row.tolist()
|
163
168
|
parsers_unfinished = [parsers_unfinished[row].copy() for row in best_tokens_row]
|
164
|
-
for parser, token in zip(parsers_unfinished, best_tokens.tolist()):
|
165
|
-
parser.
|
169
|
+
for parser, token, row in zip(parsers_unfinished, best_tokens.tolist(), best_tokens_row):
|
170
|
+
if not parser.finished:
|
171
|
+
parser.step(token)
|
166
172
|
|
167
173
|
# Special case of vanilla beam search where all answers are valid
|
168
174
|
# Warning : In this case model will not stop on end_of_sentence token
|
@@ -180,10 +186,11 @@ def set_slice_row(x : torch.Tensor, slices : torch.IntTensor, value) -> torch.Te
|
|
180
186
|
x[i].index_fill_(0, indices[i], 0)
|
181
187
|
|
182
188
|
@torch.no_grad()
|
183
|
-
def divergent_logprob(input_ids : torch.Tensor, attention_mask : torch.Tensor | None, model : GPT2LMHeadModel,
|
189
|
+
def divergent_logprob(input_ids : torch.Tensor, attention_mask : torch.Tensor | None, model : "GPT2LMHeadModel",
|
184
190
|
parsers : Parser | list[Parser] | None, batch_size=32,
|
185
191
|
start : int | torch.IntTensor = None, end_symb=DEFAULT_END_SYMB, optimize_gpu_mem=True) -> torch.FloatTensor:
|
186
192
|
if start is None:
|
193
|
+
# Start at 1 because first token logprobs cannot be computed
|
187
194
|
start = 1
|
188
195
|
if isinstance(start, int):
|
189
196
|
start = torch.tensor([start]*input_ids.shape[0])
|
@@ -222,8 +229,9 @@ def divergent_logprob(input_ids : torch.Tensor, attention_mask : torch.Tensor |
|
|
222
229
|
for input_id, att in zip(input_ids[i, start:].tolist(), attention_mask[i, start:].tolist()):
|
223
230
|
if not att:
|
224
231
|
break
|
232
|
+
assert not parser.finished
|
225
233
|
parser.step(input_id)
|
226
|
-
next_tokens =
|
234
|
+
next_tokens = parser.next()
|
227
235
|
try:
|
228
236
|
next_tokens.remove(end_symb)
|
229
237
|
except ValueError:
|
@@ -79,11 +79,13 @@ def test_divergent_beamsearch(model_and_tokenizer, device, end_symb):
|
|
79
79
|
end_symb=end_symb
|
80
80
|
)
|
81
81
|
true_solutions = torch.nn.utils.rnn.pad_sequence([torch.tensor(ans) for ans in tokenized_answers], batch_first=True, padding_value=pad_token_id)
|
82
|
-
|
82
|
+
|
83
83
|
assert torch.isclose(scores[0], logprob_paris_diverge), "Beam search did not return the expected score"
|
84
84
|
assert torch.isclose(scores[1], logprob_madrid), "Beam search did not return the expected score"
|
85
85
|
assert torch.isclose(scores[2], logprob_paris_hilton), "Beam search did not return the expected score"
|
86
86
|
assert torch.isclose(scores[3], logprob_garbage), "Beam search did not return the expected score"
|
87
|
+
assert (solutions == true_solutions).all(), "Beam search did not return the expected solutions"
|
88
|
+
|
87
89
|
|
88
90
|
|
89
91
|
@pytest.mark.parametrize("device", ['cpu', 'cuda'])
|
@@ -215,9 +217,8 @@ def test_element_wise_equivalence_divergent_logprob(fakemodel_and_tokenizer, dev
|
|
215
217
|
'Google is owned by Alphabet'
|
216
218
|
]
|
217
219
|
|
218
|
-
multi_choices_parser = MultiChoicesParser([texts])
|
219
|
-
|
220
220
|
inputs = tokenizer(texts, return_tensors='pt', padding=True).to(device)
|
221
|
+
multi_choices_parser = MultiChoicesParser([[x[1:] for x in tokenizer(texts).input_ids]])
|
221
222
|
|
222
223
|
logprobs_global = divergent_logprob(inputs.input_ids, inputs.attention_mask, model, multi_choices_parser)
|
223
224
|
|
@@ -73,28 +73,30 @@ wheels = [
|
|
73
73
|
|
74
74
|
[[package]]
|
75
75
|
name = "divergent-beamsearch"
|
76
|
-
version = "0.
|
76
|
+
version = "0.2.2"
|
77
77
|
source = { editable = "." }
|
78
78
|
dependencies = [
|
79
79
|
{ name = "multi-choices-parser" },
|
80
80
|
{ name = "torch" },
|
81
|
-
{ name = "transformers" },
|
82
81
|
]
|
83
82
|
|
84
83
|
[package.dev-dependencies]
|
85
84
|
dev = [
|
86
85
|
{ name = "pytest" },
|
86
|
+
{ name = "transformers" },
|
87
87
|
]
|
88
88
|
|
89
89
|
[package.metadata]
|
90
90
|
requires-dist = [
|
91
|
-
{ name = "multi-choices-parser", specifier = ">=0.
|
91
|
+
{ name = "multi-choices-parser", specifier = ">=0.10.0" },
|
92
92
|
{ name = "torch", specifier = ">=2.0.0" },
|
93
|
-
{ name = "transformers", specifier = ">=4.47.1" },
|
94
93
|
]
|
95
94
|
|
96
95
|
[package.metadata.requires-dev]
|
97
|
-
dev = [
|
96
|
+
dev = [
|
97
|
+
{ name = "pytest", specifier = ">=8.3.4" },
|
98
|
+
{ name = "transformers", specifier = ">=4.47.1" },
|
99
|
+
]
|
98
100
|
|
99
101
|
[[package]]
|
100
102
|
name = "filelock"
|
@@ -221,11 +223,18 @@ wheels = [
|
|
221
223
|
|
222
224
|
[[package]]
|
223
225
|
name = "multi-choices-parser"
|
224
|
-
version = "0.
|
225
|
-
source = { registry = "https://pypi.org/simple" }
|
226
|
-
|
227
|
-
|
228
|
-
{ url = "https://files.pythonhosted.org/packages/
|
226
|
+
version = "0.10.0"
|
227
|
+
source = { registry = "https://pypi.org/simple" }
|
228
|
+
wheels = [
|
229
|
+
{ url = "https://files.pythonhosted.org/packages/25/59/233da6ab703cf3243dffd2180d082a91a45caf720723309090cee3353da7/multi_choices_parser-0.10.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3019671e9ed6daa0fb8c746cd9f52557da280d57a3ba938f72a34db336671980", size = 99282 },
|
230
|
+
{ url = "https://files.pythonhosted.org/packages/30/b0/82b5ea3ebb500df180cf15e2d7d43bbcef1d58b122206f0b4616bf1dabf5/multi_choices_parser-0.10.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bebb99bf5096e8f40ee584e5860efa10e9484a50c7747360f313c761e16ed5c5", size = 139668 },
|
231
|
+
{ url = "https://files.pythonhosted.org/packages/2d/f4/b7e12764e7366b01d7fa5fdd177480967492219b1d7ffd5c6a35f8117247/multi_choices_parser-0.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:bee4d14b626fa9f8290670047bc0cc358c0a3dddd0dc104e9e844ed4b1b43203", size = 105669 },
|
232
|
+
{ url = "https://files.pythonhosted.org/packages/d6/08/f6eacab1476d99b64443433e5a683afaf79f8ae6798edf12a7535a7a02af/multi_choices_parser-0.10.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:818d292bfa2d35e0cbee4608cb2f4f9e223c7a8b8d3f94a83dc83d05f4dd71ef", size = 99485 },
|
233
|
+
{ url = "https://files.pythonhosted.org/packages/02/6a/6ecfcf3b14972807cf3eb34d960691116bbdd655ba5466905543b0fb0a53/multi_choices_parser-0.10.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e5ed7468834cf9502d3aa3fe71eaf9b2a28e3e40eb60744ded7a7605eed3612", size = 138236 },
|
234
|
+
{ url = "https://files.pythonhosted.org/packages/62/04/7ab6935f99d275fba202cd21b2cd0fb2f775237c6b57ad247cbb95e4db53/multi_choices_parser-0.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:fbc8c4b6f8bbf9e2ead1c228b6a5be9fac4d91854797b430fcb05d91ba96f8dd", size = 105784 },
|
235
|
+
{ url = "https://files.pythonhosted.org/packages/27/88/8fb06ff9341e4a09b714939d515e583678e1620f2d3a4536e4776a4ad92a/multi_choices_parser-0.10.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:541f5ff7da3cb7fb1b4ef2d114ee02637322c058c84345283c9c3194c7207e31", size = 99523 },
|
236
|
+
{ url = "https://files.pythonhosted.org/packages/e6/da/7f853bb1e676d74c85d25a1023674bcb0407d9a222ce9f65d56de4025dd5/multi_choices_parser-0.10.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ead4849057c50608a48eb498a95a622a2b0151e20c871ed9ec27ed27eb20108d", size = 138524 },
|
237
|
+
{ url = "https://files.pythonhosted.org/packages/ca/2f/7b8baffc032b503fc1075fa0be19c8ab3b56265b8c3a763bfac6c27b835f/multi_choices_parser-0.10.0-cp313-cp313-win_amd64.whl", hash = "sha256:ad17159761164672895efe07c75e1d872c8c40e74e89aa9cb6ff74fd9a81362d", size = 105781 },
|
229
238
|
]
|
230
239
|
|
231
240
|
[[package]]
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{divergent_beamsearch-0.2.0 → divergent_beamsearch-0.2.2}/src/divergent_beamsearch/__init__.py
RENAMED
File without changes
|