divergent-beamsearch 0.1.5__py3-none-any.whl → 0.1.6__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- divergent_beamsearch/algorithm.py +9 -6
- {divergent_beamsearch-0.1.5.dist-info → divergent_beamsearch-0.1.6.dist-info}/METADATA +2 -2
- divergent_beamsearch-0.1.6.dist-info/RECORD +6 -0
- divergent_beamsearch-0.1.5.dist-info/RECORD +0 -6
- {divergent_beamsearch-0.1.5.dist-info → divergent_beamsearch-0.1.6.dist-info}/WHEEL +0 -0
- {divergent_beamsearch-0.1.5.dist-info → divergent_beamsearch-0.1.6.dist-info}/licenses/LICENCE +0 -0
@@ -1,7 +1,7 @@
|
|
1
1
|
import math
|
2
2
|
import torch
|
3
3
|
from transformers import GPT2LMHeadModel
|
4
|
-
from multi_choices_parser import
|
4
|
+
from multi_choices_parser import DEFAULT_END_SYMB
|
5
5
|
|
6
6
|
|
7
7
|
class Parser:
|
@@ -19,10 +19,10 @@ def get_parsers_tokens(parsers : list[Parser], end_symb) -> tuple[list, list[int
|
|
19
19
|
can_end = []
|
20
20
|
for parser in parsers:
|
21
21
|
tokens = list(parser.next())
|
22
|
-
|
23
|
-
can_end.append(True)
|
22
|
+
try:
|
24
23
|
tokens.remove(end_symb)
|
25
|
-
|
24
|
+
can_end.append(True)
|
25
|
+
except ValueError:
|
26
26
|
can_end.append(False)
|
27
27
|
parsers_tokens.append(tokens)
|
28
28
|
return parsers_tokens, can_end
|
@@ -77,7 +77,7 @@ class AcceptEverythingParser(Parser):
|
|
77
77
|
return self
|
78
78
|
|
79
79
|
@torch.no_grad()
|
80
|
-
def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int, max_length : int, parser : Parser, pad_token_id : int, batch_size=32, num_solutions = None, end_symb=
|
80
|
+
def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int, max_length : int, parser : Parser, pad_token_id : int, batch_size=32, num_solutions = None, end_symb=DEFAULT_END_SYMB) -> tuple[torch.Tensor, torch.Tensor]:
|
81
81
|
assert input_ids.shape[0] == 1, "Batch size must be 1"
|
82
82
|
device = input_ids.device
|
83
83
|
input_ids = input_ids.cpu()
|
@@ -140,6 +140,7 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
|
|
140
140
|
parser.step(token)
|
141
141
|
|
142
142
|
# Special case of vanilla beam search where all answers are valid
|
143
|
+
# Warning : In this case model will not stop on end_of_sentence token
|
143
144
|
if vanilla:
|
144
145
|
order = scores_unfinished.argsort(descending=True)
|
145
146
|
scores_finished = scores_unfinished[order][:num_solutions]
|
@@ -154,7 +155,9 @@ def set_slice_row(x : torch.Tensor, slices : torch.IntTensor, value) -> torch.Te
|
|
154
155
|
x[i].index_fill_(0, indices[i], 0)
|
155
156
|
|
156
157
|
@torch.no_grad()
|
157
|
-
def divergent_logprob(input_ids : torch.Tensor, attention_mask : torch.Tensor | None, model : GPT2LMHeadModel,
|
158
|
+
def divergent_logprob(input_ids : torch.Tensor, attention_mask : torch.Tensor | None, model : GPT2LMHeadModel,
|
159
|
+
parsers : Parser | list[Parser] | None, batch_size=32,
|
160
|
+
start : int | torch.IntTensor = None, end_symb=DEFAULT_END_SYMB) -> torch.FloatTensor:
|
158
161
|
if start is None:
|
159
162
|
start = 0
|
160
163
|
if isinstance(start, int):
|
@@ -1,10 +1,10 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: divergent-beamsearch
|
3
|
-
Version: 0.1.5
|
3
|
+
Version: 0.1.6
|
4
4
|
Summary: A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject.
|
5
5
|
License-File: LICENCE
|
6
6
|
Requires-Python: >=3.11
|
7
|
-
Requires-Dist: multi-choices-parser>=0.9.
|
7
|
+
Requires-Dist: multi-choices-parser>=0.9.61
|
8
8
|
Requires-Dist: torch>=2.0.0
|
9
9
|
Requires-Dist: transformers>=4.47.1
|
10
10
|
Description-Content-Type: text/markdown
|
@@ -0,0 +1,6 @@
|
|
1
|
+
divergent_beamsearch/__init__.py,sha256=qrpVRoT3d-q1N9fJnzHI2X13e71LDY4-6eLOQ_gwCqQ,62
|
2
|
+
divergent_beamsearch/algorithm.py,sha256=73BsS5IU1_4Aj11LdQebKofpCO-Mo8BXtDfx-AEYxXA,8835
|
3
|
+
divergent_beamsearch-0.1.6.dist-info/METADATA,sha256=cm-VyQfzk9sklvIkFXgEfI4A4ktWddIO5CvtaL7Vkng,2826
|
4
|
+
divergent_beamsearch-0.1.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
5
|
+
divergent_beamsearch-0.1.6.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
|
6
|
+
divergent_beamsearch-0.1.6.dist-info/RECORD,,
|
@@ -1,6 +0,0 @@
|
|
1
|
-
divergent_beamsearch/__init__.py,sha256=qrpVRoT3d-q1N9fJnzHI2X13e71LDY4-6eLOQ_gwCqQ,62
|
2
|
-
divergent_beamsearch/algorithm.py,sha256=b44kA2_M0HfOSC6LYQmu7W_JnGAl1u9Sz_91jMjvWg0,8688
|
3
|
-
divergent_beamsearch-0.1.5.dist-info/METADATA,sha256=sRitbfgDp8YqLXZzZmZw7Nd1pJWv8T4HJoSqjhO4ztI,2826
|
4
|
-
divergent_beamsearch-0.1.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
5
|
-
divergent_beamsearch-0.1.5.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
|
6
|
-
divergent_beamsearch-0.1.5.dist-info/RECORD,,
|
File without changes
|
{divergent_beamsearch-0.1.5.dist-info → divergent_beamsearch-0.1.6.dist-info}/licenses/LICENCE
RENAMED
File without changes
|