divergent-beamsearch 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1 +1 @@
1
- from .algorithm import divergent_beamsearch
1
+ from .algorithm import divergent_beamsearch, divergent_logprob
@@ -14,7 +14,7 @@ class Parser:
14
14
  def copy(self):
15
15
  raise NotImplementedError
16
16
 
17
- def get_parsers_tokens(parsers : list[Parser]) -> tuple[list, list[int]]:
17
+ def get_parsers_tokens(parsers : list[Parser], end_symb) -> tuple[list, list[int]]:
18
18
  parsers_tokens = []
19
19
  can_end = []
20
20
  for parser in parsers:
@@ -77,7 +77,7 @@ class AcceptEverythingParser(Parser):
77
77
  return self
78
78
 
79
79
  @torch.no_grad()
80
- def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int, max_length : int, parser : Parser, pad_token_id : int, batch_size=32, num_solutions = None) -> tuple[torch.Tensor, torch.Tensor]:
80
+ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int, max_length : int, parser : Parser, pad_token_id : int, batch_size=32, num_solutions = None, end_symb=end_symb) -> tuple[torch.Tensor, torch.Tensor]:
81
81
  assert input_ids.shape[0] == 1, "Batch size must be 1"
82
82
  device = input_ids.device
83
83
  input_ids = input_ids.cpu()
@@ -101,7 +101,7 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
101
101
  if len(input_ids_unfinished) == 0:
102
102
  break
103
103
  pred = batched_inference_logits(model, input_ids_unfinished.to(device), batch_size=batch_size)[:, -1].cpu()
104
- parsers_tokens, can_end = get_parsers_tokens(parsers_unfinished)
104
+ parsers_tokens, can_end = get_parsers_tokens(parsers_unfinished, end_symb)
105
105
  logprobs = torch.log_softmax(pred, dim=-1)
106
106
  logprobs_filtered = apply_mask_tokens(logprobs, parsers_tokens)
107
107
  if len(logprobs_filtered):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: divergent-beamsearch
3
- Version: 0.1.3
3
+ Version: 0.1.5
4
4
  Summary: A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject.
5
5
  License-File: LICENCE
6
6
  Requires-Python: >=3.11
@@ -0,0 +1,6 @@
1
+ divergent_beamsearch/__init__.py,sha256=qrpVRoT3d-q1N9fJnzHI2X13e71LDY4-6eLOQ_gwCqQ,62
2
+ divergent_beamsearch/algorithm.py,sha256=b44kA2_M0HfOSC6LYQmu7W_JnGAl1u9Sz_91jMjvWg0,8688
3
+ divergent_beamsearch-0.1.5.dist-info/METADATA,sha256=sRitbfgDp8YqLXZzZmZw7Nd1pJWv8T4HJoSqjhO4ztI,2826
4
+ divergent_beamsearch-0.1.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
+ divergent_beamsearch-0.1.5.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
6
+ divergent_beamsearch-0.1.5.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- divergent_beamsearch/__init__.py,sha256=Z2R1pkj4EEHMKWVZX0upeE_Jtfb6joxgYHuRNxWc8Zo,43
2
- divergent_beamsearch/algorithm.py,sha256=d0xU4OWiCEa5icdXZHoV1P-eKYftYMHhfBZMEVNkRXQ,8649
3
- divergent_beamsearch-0.1.3.dist-info/METADATA,sha256=waQn6dvg12V9753CcIQlOR_jcOvfbwAJa24FvR5awy0,2826
4
- divergent_beamsearch-0.1.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
- divergent_beamsearch-0.1.3.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
6
- divergent_beamsearch-0.1.3.dist-info/RECORD,,