divergent-beamsearch 0.1.3__tar.gz → 0.1.5__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- {divergent_beamsearch-0.1.3 → divergent_beamsearch-0.1.5}/PKG-INFO +1 -1
- {divergent_beamsearch-0.1.3 → divergent_beamsearch-0.1.5}/pyproject.toml +1 -1
- divergent_beamsearch-0.1.5/src/divergent_beamsearch/__init__.py +1 -0
- {divergent_beamsearch-0.1.3 → divergent_beamsearch-0.1.5}/src/divergent_beamsearch/algorithm.py +3 -3
- divergent_beamsearch-0.1.3/src/divergent_beamsearch/__init__.py +0 -1
- {divergent_beamsearch-0.1.3 → divergent_beamsearch-0.1.5}/.gitignore +0 -0
- {divergent_beamsearch-0.1.3 → divergent_beamsearch-0.1.5}/.python-version +0 -0
- {divergent_beamsearch-0.1.3 → divergent_beamsearch-0.1.5}/LICENCE +0 -0
- {divergent_beamsearch-0.1.3 → divergent_beamsearch-0.1.5}/README.md +0 -0
- {divergent_beamsearch-0.1.3 → divergent_beamsearch-0.1.5}/tests/test_beamsearch.py +0 -0
- {divergent_beamsearch-0.1.3 → divergent_beamsearch-0.1.5}/uv.lock +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: divergent-beamsearch
|
3
|
-
Version: 0.1.
|
3
|
+
Version: 0.1.5
|
4
4
|
Summary: A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject.
|
5
5
|
License-File: LICENCE
|
6
6
|
Requires-Python: >=3.11
|
@@ -1,6 +1,6 @@
|
|
1
1
|
[project]
|
2
2
|
name = "divergent-beamsearch"
|
3
|
-
version = "0.1.
|
3
|
+
version = "0.1.5"
|
4
4
|
description = "A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject."
|
5
5
|
readme = "README.md"
|
6
6
|
requires-python = ">=3.11"
|
@@ -0,0 +1 @@
|
|
1
|
+
from .algorithm import divergent_beamsearch, divergent_logprob
|
{divergent_beamsearch-0.1.3 → divergent_beamsearch-0.1.5}/src/divergent_beamsearch/algorithm.py
RENAMED
@@ -14,7 +14,7 @@ class Parser:
|
|
14
14
|
def copy(self):
|
15
15
|
raise NotImplementedError
|
16
16
|
|
17
|
-
def get_parsers_tokens(parsers : list[Parser]) -> tuple[list, list[int]]:
|
17
|
+
def get_parsers_tokens(parsers : list[Parser], end_symb) -> tuple[list, list[int]]:
|
18
18
|
parsers_tokens = []
|
19
19
|
can_end = []
|
20
20
|
for parser in parsers:
|
@@ -77,7 +77,7 @@ class AcceptEverythingParser(Parser):
|
|
77
77
|
return self
|
78
78
|
|
79
79
|
@torch.no_grad()
|
80
|
-
def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int, max_length : int, parser : Parser, pad_token_id : int, batch_size=32, num_solutions = None) -> tuple[torch.Tensor, torch.Tensor]:
|
80
|
+
def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int, max_length : int, parser : Parser, pad_token_id : int, batch_size=32, num_solutions = None, end_symb=end_symb) -> tuple[torch.Tensor, torch.Tensor]:
|
81
81
|
assert input_ids.shape[0] == 1, "Batch size must be 1"
|
82
82
|
device = input_ids.device
|
83
83
|
input_ids = input_ids.cpu()
|
@@ -101,7 +101,7 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
|
|
101
101
|
if len(input_ids_unfinished) == 0:
|
102
102
|
break
|
103
103
|
pred = batched_inference_logits(model, input_ids_unfinished.to(device), batch_size=batch_size)[:, -1].cpu()
|
104
|
-
parsers_tokens, can_end = get_parsers_tokens(parsers_unfinished)
|
104
|
+
parsers_tokens, can_end = get_parsers_tokens(parsers_unfinished, end_symb)
|
105
105
|
logprobs = torch.log_softmax(pred, dim=-1)
|
106
106
|
logprobs_filtered = apply_mask_tokens(logprobs, parsers_tokens)
|
107
107
|
if len(logprobs_filtered):
|
@@ -1 +0,0 @@
|
|
1
|
-
from .algorithm import divergent_beamsearch
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|