divergent-beamsearch 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -47,12 +47,29 @@ def log1mexp(x: torch.Tensor) -> torch.Tensor:
47
47
  (-x.exp()).log1p(),
48
48
  )
49
49
 
50
class AcceptEverythingParser:
    """Parser stand-in that never restricts which token may come next.

    It exposes ``step``, ``next`` and ``copy``. Every token id in
    ``range(vocab_size)`` is reported as allowed at every position, and
    consuming a token changes nothing.
    """

    def __init__(self, vocab_size : int):
        """Precompute the tuple of all token ids ``0 .. vocab_size - 1``."""
        every_token_id = range(vocab_size)
        self.vocab_size = vocab_size
        # Built once up front; ``next`` hands back this same tuple each call.
        self.tokens = tuple(every_token_id)

    def step(self, token):
        """Accept ``token`` without recording it; nothing is ever rejected."""
        return None

    def next(self):
        """Return the tuple of all token ids as the allowed continuations."""
        allowed = self.tokens
        return allowed

    def copy(self):
        """Return this same instance rather than a duplicate.

        Sharing is safe because ``step`` never mutates the object.
        """
        return self
63
+
50
64
  @torch.no_grad()
51
65
  def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int, max_length : int, multi_choices_parser : MultiChoicesParser, pad_token_id : int, batch_size=32, num_solutions = None) -> tuple[torch.Tensor, torch.Tensor]:
52
66
  assert input_ids.shape[0] == 1, "Batch size must be 1"
53
67
 
54
68
  if num_solutions is None:
55
69
  num_solutions = beam_size
70
+ vanilla = multi_choices_parser is None
71
+ if vanilla:
72
+ multi_choices_parser = AcceptEverythingParser(model.config.vocab_size)
56
73
 
57
74
  parsers_unfinished = [multi_choices_parser]
58
75
  scores_finished = torch.tensor([], dtype=torch.float)
@@ -73,9 +90,10 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
73
90
  logprobs_filtered = apply_mask_tokens(logprobs, parsers_tokens)
74
91
  if len(logprobs_filtered):
75
92
  topk = torch.topk(logprobs_filtered, beam_size, dim=-1) # shape (batch_size, beam_size)
76
- topk_global = topk.values.flatten().topk(beam_size)
93
+ values = topk.values + scores_unfinished.unsqueeze(-1)
94
+ topk_global = values.flatten().topk(beam_size)
77
95
  best_tokens_row = topk_global.indices // beam_size
78
- best_tokens, best_tokens_logprobs = topk.indices[best_tokens_row, topk_global.indices % beam_size], topk_global.values
96
+ best_tokens, best_tokens_logprobs = topk.indices[best_tokens_row, topk_global.indices % beam_size], topk.values[best_tokens_row, topk_global.indices % beam_size]
79
97
  notinf = ~best_tokens_logprobs.isinf()
80
98
  best_tokens, best_tokens_row, best_tokens_logprobs = best_tokens[notinf], best_tokens_row[notinf], best_tokens_logprobs[notinf]
81
99
  else:
@@ -104,9 +122,11 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
104
122
  parsers_unfinished = [parsers_unfinished[row].copy() for row in best_tokens_row]
105
123
  for parser, token in zip(parsers_unfinished, best_tokens.tolist()):
106
124
  parser.step(token)
125
+
126
+ # Special case of vanilla beam search where all answers are valid
127
+ if vanilla:
128
+ order = scores_unfinished.argsort(descending=True)
129
+ scores_finished = scores_unfinished[order][:num_solutions]
130
+ solutions_finished = solutions_unfinished[order][:num_solutions]
107
131
 
108
132
  return scores_finished, solutions_finished
109
-
110
-
111
-
112
-
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: divergent-beamsearch
3
- Version: 0.1.0
3
+ Version: 0.1.1
4
4
  Summary: A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject.
5
5
  License-File: LICENCE
6
6
  Requires-Python: >=3.11
@@ -0,0 +1,6 @@
1
+ divergent_beamsearch/__init__.py,sha256=Z2R1pkj4EEHMKWVZX0upeE_Jtfb6joxgYHuRNxWc8Zo,43
2
+ divergent_beamsearch/algorithm.py,sha256=0NvVocEHVlIAXnfKhiUW6PEbG_L7uBgE7NGJtaoJ-Rw,6136
3
+ divergent_beamsearch-0.1.1.dist-info/METADATA,sha256=dFlRtT8pvNDcUDZaac59zsLAWHB5M5maMkPO-DKFDGI,2826
4
+ divergent_beamsearch-0.1.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
+ divergent_beamsearch-0.1.1.dist-info/licenses/LICENCE,sha256=jDQOOFKJxgrQwcEyipwKcKzj5IX_paD_41c3iOjH3qw,1095
6
+ divergent_beamsearch-0.1.1.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- divergent_beamsearch/__init__.py,sha256=Z2R1pkj4EEHMKWVZX0upeE_Jtfb6joxgYHuRNxWc8Zo,43
2
- divergent_beamsearch/algorithm.py,sha256=6cWp6XHepSn1rjqQFkASxd8k3OEUarKNAiqPKfrA78k,5324
3
- divergent_beamsearch-0.1.0.dist-info/METADATA,sha256=UCkp3rgFZ89kmwFVy_N_dEy45NTGN_yFhf-J6WCCR4U,2826
4
- divergent_beamsearch-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
- divergent_beamsearch-0.1.0.dist-info/licenses/LICENCE,sha256=jDQOOFKJxgrQwcEyipwKcKzj5IX_paD_41c3iOjH3qw,1095
6
- divergent_beamsearch-0.1.0.dist-info/RECORD,,