divergent-beamsearch 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- divergent_beamsearch/__init__.py +1 -1
- divergent_beamsearch/algorithm.py +3 -3
- {divergent_beamsearch-0.1.3.dist-info → divergent_beamsearch-0.1.5.dist-info}/METADATA +1 -1
- divergent_beamsearch-0.1.5.dist-info/RECORD +6 -0
- divergent_beamsearch-0.1.3.dist-info/RECORD +0 -6
- {divergent_beamsearch-0.1.3.dist-info → divergent_beamsearch-0.1.5.dist-info}/WHEEL +0 -0
- {divergent_beamsearch-0.1.3.dist-info → divergent_beamsearch-0.1.5.dist-info}/licenses/LICENCE +0 -0
divergent_beamsearch/__init__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
from .algorithm import divergent_beamsearch
|
1
|
+
from .algorithm import divergent_beamsearch, divergent_logprob
|
divergent_beamsearch/algorithm.py
CHANGED
@@ -14,7 +14,7 @@ class Parser:
|
|
14
14
|
def copy(self):
|
15
15
|
raise NotImplementedError
|
16
16
|
|
17
|
-
def get_parsers_tokens(parsers : list[Parser]) -> tuple[list, list[int]]:
|
17
|
+
def get_parsers_tokens(parsers : list[Parser], end_symb) -> tuple[list, list[int]]:
|
18
18
|
parsers_tokens = []
|
19
19
|
can_end = []
|
20
20
|
for parser in parsers:
|
@@ -77,7 +77,7 @@ class AcceptEverythingParser(Parser):
|
|
77
77
|
return self
|
78
78
|
|
79
79
|
@torch.no_grad()
|
80
|
-
def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int, max_length : int, parser : Parser, pad_token_id : int, batch_size=32, num_solutions = None) -> tuple[torch.Tensor, torch.Tensor]:
|
80
|
+
def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int, max_length : int, parser : Parser, pad_token_id : int, batch_size=32, num_solutions = None, end_symb=end_symb) -> tuple[torch.Tensor, torch.Tensor]:
|
81
81
|
assert input_ids.shape[0] == 1, "Batch size must be 1"
|
82
82
|
device = input_ids.device
|
83
83
|
input_ids = input_ids.cpu()
|
@@ -101,7 +101,7 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
|
|
101
101
|
if len(input_ids_unfinished) == 0:
|
102
102
|
break
|
103
103
|
pred = batched_inference_logits(model, input_ids_unfinished.to(device), batch_size=batch_size)[:, -1].cpu()
|
104
|
-
parsers_tokens, can_end = get_parsers_tokens(parsers_unfinished)
|
104
|
+
parsers_tokens, can_end = get_parsers_tokens(parsers_unfinished, end_symb)
|
105
105
|
logprobs = torch.log_softmax(pred, dim=-1)
|
106
106
|
logprobs_filtered = apply_mask_tokens(logprobs, parsers_tokens)
|
107
107
|
if len(logprobs_filtered):
|
{divergent_beamsearch-0.1.3.dist-info → divergent_beamsearch-0.1.5.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: divergent-beamsearch
|
3
|
-
Version: 0.1.3
|
3
|
+
Version: 0.1.5
|
4
4
|
Summary: A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject.
|
5
5
|
License-File: LICENCE
|
6
6
|
Requires-Python: >=3.11
|
@@ -0,0 +1,6 @@
|
|
1
|
+
divergent_beamsearch/__init__.py,sha256=qrpVRoT3d-q1N9fJnzHI2X13e71LDY4-6eLOQ_gwCqQ,62
|
2
|
+
divergent_beamsearch/algorithm.py,sha256=b44kA2_M0HfOSC6LYQmu7W_JnGAl1u9Sz_91jMjvWg0,8688
|
3
|
+
divergent_beamsearch-0.1.5.dist-info/METADATA,sha256=sRitbfgDp8YqLXZzZmZw7Nd1pJWv8T4HJoSqjhO4ztI,2826
|
4
|
+
divergent_beamsearch-0.1.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
5
|
+
divergent_beamsearch-0.1.5.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
|
6
|
+
divergent_beamsearch-0.1.5.dist-info/RECORD,,
|
@@ -1,6 +0,0 @@
|
|
1
|
-
divergent_beamsearch/__init__.py,sha256=Z2R1pkj4EEHMKWVZX0upeE_Jtfb6joxgYHuRNxWc8Zo,43
|
2
|
-
divergent_beamsearch/algorithm.py,sha256=d0xU4OWiCEa5icdXZHoV1P-eKYftYMHhfBZMEVNkRXQ,8649
|
3
|
-
divergent_beamsearch-0.1.3.dist-info/METADATA,sha256=waQn6dvg12V9753CcIQlOR_jcOvfbwAJa24FvR5awy0,2826
|
4
|
-
divergent_beamsearch-0.1.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
5
|
-
divergent_beamsearch-0.1.3.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
|
6
|
-
divergent_beamsearch-0.1.3.dist-info/RECORD,,
|
{divergent_beamsearch-0.1.3.dist-info → divergent_beamsearch-0.1.5.dist-info}/WHEEL
RENAMED
File without changes
|
{divergent_beamsearch-0.1.3.dist-info → divergent_beamsearch-0.1.5.dist-info}/licenses/LICENCE
RENAMED
File without changes
|