divergent-beamsearch 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -47,12 +47,29 @@ def log1mexp(x: torch.Tensor) -> torch.Tensor:
47
47
  (-x.exp()).log1p(),
48
48
  )
49
49
 
50
+ class AcceptEverythingParser:
51
+ def __init__(self, vocab_size : int):
52
+ self.vocab_size = vocab_size
53
+ self.tokens = tuple(range(vocab_size))
54
+
55
+ def step(self, token):
56
+ pass
57
+
58
+ def next(self):
59
+ return self.tokens
60
+
61
+ def copy(self):
62
+ return self
63
+
50
64
  @torch.no_grad()
51
65
  def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int, max_length : int, multi_choices_parser : MultiChoicesParser, pad_token_id : int, batch_size=32, num_solutions = None) -> tuple[torch.Tensor, torch.Tensor]:
52
66
  assert input_ids.shape[0] == 1, "Batch size must be 1"
53
67
 
54
68
  if num_solutions is None:
55
69
  num_solutions = beam_size
70
+ vanilla = multi_choices_parser is None
71
+ if vanilla:
72
+ multi_choices_parser = AcceptEverythingParser(model.config.vocab_size)
56
73
 
57
74
  parsers_unfinished = [multi_choices_parser]
58
75
  scores_finished = torch.tensor([], dtype=torch.float)
@@ -73,9 +90,10 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
73
90
  logprobs_filtered = apply_mask_tokens(logprobs, parsers_tokens)
74
91
  if len(logprobs_filtered):
75
92
  topk = torch.topk(logprobs_filtered, beam_size, dim=-1) # shape (batch_size, beam_size)
76
- topk_global = topk.values.flatten().topk(beam_size)
93
+ values = topk.values + scores_unfinished.unsqueeze(-1)
94
+ topk_global = values.flatten().topk(beam_size)
77
95
  best_tokens_row = topk_global.indices // beam_size
78
- best_tokens, best_tokens_logprobs = topk.indices[best_tokens_row, topk_global.indices % beam_size], topk_global.values
96
+ best_tokens, best_tokens_logprobs = topk.indices[best_tokens_row, topk_global.indices % beam_size], topk.values[best_tokens_row, topk_global.indices % beam_size]
79
97
  notinf = ~best_tokens_logprobs.isinf()
80
98
  best_tokens, best_tokens_row, best_tokens_logprobs = best_tokens[notinf], best_tokens_row[notinf], best_tokens_logprobs[notinf]
81
99
  else:
@@ -104,9 +122,11 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
104
122
  parsers_unfinished = [parsers_unfinished[row].copy() for row in best_tokens_row]
105
123
  for parser, token in zip(parsers_unfinished, best_tokens.tolist()):
106
124
  parser.step(token)
125
+
126
+ # Special case of vanilla beam search where all answers are valid
127
+ if vanilla:
128
+ order = scores_unfinished.argsort(descending=True)
129
+ scores_finished = scores_unfinished[order][:num_solutions]
130
+ solutions_finished = solutions_unfinished[order][:num_solutions]
107
131
 
108
132
  return scores_finished, solutions_finished
109
-
110
-
111
-
112
-
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: divergent-beamsearch
3
- Version: 0.1.0
3
+ Version: 0.1.1
4
4
  Summary: A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject.
5
5
  License-File: LICENCE
6
6
  Requires-Python: >=3.11
@@ -0,0 +1,6 @@
1
+ divergent_beamsearch/__init__.py,sha256=Z2R1pkj4EEHMKWVZX0upeE_Jtfb6joxgYHuRNxWc8Zo,43
2
+ divergent_beamsearch/algorithm.py,sha256=0NvVocEHVlIAXnfKhiUW6PEbG_L7uBgE7NGJtaoJ-Rw,6136
3
+ divergent_beamsearch-0.1.1.dist-info/METADATA,sha256=dFlRtT8pvNDcUDZaac59zsLAWHB5M5maMkPO-DKFDGI,2826
4
+ divergent_beamsearch-0.1.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
+ divergent_beamsearch-0.1.1.dist-info/licenses/LICENCE,sha256=jDQOOFKJxgrQwcEyipwKcKzj5IX_paD_41c3iOjH3qw,1095
6
+ divergent_beamsearch-0.1.1.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- divergent_beamsearch/__init__.py,sha256=Z2R1pkj4EEHMKWVZX0upeE_Jtfb6joxgYHuRNxWc8Zo,43
2
- divergent_beamsearch/algorithm.py,sha256=6cWp6XHepSn1rjqQFkASxd8k3OEUarKNAiqPKfrA78k,5324
3
- divergent_beamsearch-0.1.0.dist-info/METADATA,sha256=UCkp3rgFZ89kmwFVy_N_dEy45NTGN_yFhf-J6WCCR4U,2826
4
- divergent_beamsearch-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
- divergent_beamsearch-0.1.0.dist-info/licenses/LICENCE,sha256=jDQOOFKJxgrQwcEyipwKcKzj5IX_paD_41c3iOjH3qw,1095
6
- divergent_beamsearch-0.1.0.dist-info/RECORD,,