divergent-beamsearch 0.1.4__tar.gz → 0.1.5__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: divergent-beamsearch
3
- Version: 0.1.4
3
+ Version: 0.1.5
4
4
  Summary: A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject.
5
5
  License-File: LICENCE
6
6
  Requires-Python: >=3.11
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "divergent-beamsearch"
3
- version = "0.1.4"
3
+ version = "0.1.5"
4
4
  description = "A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject."
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.11"
@@ -14,7 +14,7 @@ class Parser:
14
14
  def copy(self):
15
15
  raise NotImplementedError
16
16
 
17
- def get_parsers_tokens(parsers : list[Parser]) -> tuple[list, list[int]]:
17
+ def get_parsers_tokens(parsers : list[Parser], end_symb) -> tuple[list, list[int]]:
18
18
  parsers_tokens = []
19
19
  can_end = []
20
20
  for parser in parsers:
@@ -77,7 +77,7 @@ class AcceptEverythingParser(Parser):
77
77
  return self
78
78
 
79
79
  @torch.no_grad()
80
- def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int, max_length : int, parser : Parser, pad_token_id : int, batch_size=32, num_solutions = None) -> tuple[torch.Tensor, torch.Tensor]:
80
+ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int, max_length : int, parser : Parser, pad_token_id : int, batch_size=32, num_solutions = None, end_symb=end_symb) -> tuple[torch.Tensor, torch.Tensor]:
81
81
  assert input_ids.shape[0] == 1, "Batch size must be 1"
82
82
  device = input_ids.device
83
83
  input_ids = input_ids.cpu()
@@ -101,7 +101,7 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
101
101
  if len(input_ids_unfinished) == 0:
102
102
  break
103
103
  pred = batched_inference_logits(model, input_ids_unfinished.to(device), batch_size=batch_size)[:, -1].cpu()
104
- parsers_tokens, can_end = get_parsers_tokens(parsers_unfinished)
104
+ parsers_tokens, can_end = get_parsers_tokens(parsers_unfinished, end_symb)
105
105
  logprobs = torch.log_softmax(pred, dim=-1)
106
106
  logprobs_filtered = apply_mask_tokens(logprobs, parsers_tokens)
107
107
  if len(logprobs_filtered):