divergent-beamsearch 0.1.2-py3-none-any.whl → 0.1.4-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.
--- a/divergent_beamsearch/__init__.py
+++ b/divergent_beamsearch/__init__.py
@@ -1 +1 @@
-from .algorithm import divergent_beamsearch
+from .algorithm import divergent_beamsearch, divergent_logprob
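With this re-export, both entry points are importable from the package root. A minimal import sketch:

```python
# Only these two names are re-exported by the package's __init__.py.
from divergent_beamsearch import divergent_beamsearch, divergent_logprob
```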
--- a/divergent_beamsearch/algorithm.py
+++ b/divergent_beamsearch/algorithm.py
@@ -35,10 +35,12 @@ def apply_mask_tokens(pred : torch.Tensor, parsers_tokens):
     return pred[~pred.isinf().all(dim=-1)]
 
 
-def batched_inference_logits(model : GPT2LMHeadModel, input_ids : torch.Tensor, batch_size : int = 32) -> torch.Tensor:
+def batched_inference_logits(model : GPT2LMHeadModel, input_ids : torch.Tensor, attention_mask : torch.Tensor | None = None, batch_size : int = 32) -> torch.Tensor:
     logits = []
+    if attention_mask is None:
+        attention_mask = torch.ones_like(input_ids)
     for i in range(0, input_ids.shape[0], batch_size):
-        logits.append(model(input_ids[i:i+batch_size]).logits)
+        logits.append(model(input_ids[i:i+batch_size], attention_mask=attention_mask[i:i+batch_size]).logits)
     return torch.cat(logits, dim=0)
 
 
 def select_mask(source : list, mask : list[bool]) -> list:
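The hunk above threads an optional attention mask through the batched forward passes, defaulting to all ones when none is given. A hedged usage sketch, assuming a padded batch from a Hugging Face tokenizer (the model and tokenizer choices are illustrative, not from the diff; `batched_inference_logits` lives in the module shown here):

```python
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
from divergent_beamsearch.algorithm import batched_inference_logits

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

enc = tokenizer(["Paris is the capital of", "The Eiffel Tower is in"],
                return_tensors="pt", padding=True)
with torch.no_grad():
    # Passing the tokenizer's mask keeps padding tokens out of attention,
    # instead of relying on the all-ones default added in this version.
    logits = batched_inference_logits(model, enc.input_ids, enc.attention_mask, batch_size=2)
print(logits.shape)  # (batch, seq_len, vocab_size)
```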
@@ -98,7 +100,7 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
     for _ in range(max_length):
         if len(input_ids_unfinished) == 0:
             break
-        pred = batched_inference_logits(model, input_ids_unfinished.to(device), batch_size)[:, -1].cpu()
+        pred = batched_inference_logits(model, input_ids_unfinished.to(device), batch_size=batch_size)[:, -1].cpu()
         parsers_tokens, can_end = get_parsers_tokens(parsers_unfinished)
         logprobs = torch.log_softmax(pred, dim=-1)
         logprobs_filtered = apply_mask_tokens(logprobs, parsers_tokens)
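Note why this call-site change is needed: with `attention_mask` now the third positional parameter of `batched_inference_logits`, the old positional call would have bound `batch_size` to `attention_mask`; passing `batch_size=batch_size` by keyword keeps the call correct.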
@@ -144,3 +146,60 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
     solutions_finished = solutions_unfinished[order][:num_solutions]
 
     return scores_finished, solutions_finished
+
+
+def set_slice_row(x : torch.Tensor, slices : torch.IntTensor, value) -> None:
+    # In-place: fill x[i, start:end] with `value` for each (start, end) row of `slices`.
+    for i, (start, end) in enumerate(slices):
+        x[i].index_fill_(0, torch.arange(start, end), value)
+
+@torch.no_grad()
+def divergent_logprob(input_ids : torch.Tensor, attention_mask : torch.Tensor | None, model : GPT2LMHeadModel, parsers : Parser | list[Parser] | None, batch_size=32, start : int | torch.IntTensor | None = None) -> torch.FloatTensor:
+    if start is None:
+        start = 0
+    if isinstance(start, int):
+        start = torch.tensor([start]*input_ids.shape[0])
+    assert start.shape[0] == input_ids.shape[0]
+    # -1 because of the next-token offset
+    start = start - 1
+
+    if attention_mask is None:
+        attention_mask = torch.ones_like(input_ids)
+
+    logits = batched_inference_logits(model, input_ids, attention_mask, batch_size).cpu()
+    input_ids = input_ids.cpu()
+    attention_mask = attention_mask.cpu()
+
+    logsoftmax = torch.log_softmax(logits, dim=-1)
+    log_probs = torch.gather(
+        logsoftmax[:, :-1, :], 2, input_ids[:, 1:, None]
+    ).squeeze(-1)
+    mask = attention_mask[:, 1:].cpu().clone()
+
+    input_len = attention_mask.sum(-1)
+    pos = torch.stack([torch.zeros_like(start), start], dim=-1)
+    pos_anti = pos.flip(1)
+    pos_anti[:, -1] = input_len
+    set_slice_row(mask, pos, 0)  # zero out prompt positions so only answer tokens count
+    vanilla_prob = (log_probs * mask).sum(-1)
+    if parsers is None:
+        parsers = AcceptEverythingParser(model.config.vocab_size)
+    if not isinstance(parsers, (tuple, list)):
+        parsers = [parsers.copy() for _ in range(len(input_ids))]
+    next_possible_tokens = []
+    for i, parser in enumerate(parsers):
+        # +1 because there is no next-token offset here
+        answer_start = pos_anti[i, 0] + 1
+        for input_id, att in zip(input_ids[i, answer_start:].tolist(), attention_mask[i, answer_start:].tolist()):
+            if not att:
+                break
+            parser.step(input_id)
+        next_tokens = list(parser.next())
+        try:
+            next_tokens.remove(end_symb)
+        except ValueError:
+            pass
+        next_possible_tokens.append(next_tokens)
+    last_token_log_probs = torch.stack([log1mexp(logsoftmax[i, input_len[i]-1, tokens].logsumexp(-1)).squeeze() for i, tokens in enumerate(next_possible_tokens)])
+    prob = vanilla_prob + last_token_log_probs
+    return prob
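A hedged sketch of calling the new `divergent_logprob`: it scores the log-probability that the model generates the answer tokens and then diverges afterwards, with `start` marking the index of the first answer token. The model, tokenizer, and prompt/answer strings are assumptions for illustration; `parsers=None` falls back to the `AcceptEverythingParser` seen above.

```python
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
from divergent_beamsearch import divergent_logprob

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

prompt = "The capital of France is"   # hypothetical prompt
answer = " Paris"                     # hypothetical answer continuation
prompt_len = tokenizer(prompt, return_tensors="pt").input_ids.shape[1]
input_ids = tokenizer(prompt + answer, return_tensors="pt").input_ids

# attention_mask=None -> all ones; start = index of the first answer token.
logprob = divergent_logprob(input_ids, None, model, parsers=None, start=prompt_len)
print(float(logprob))
```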
--- a/divergent_beamsearch-0.1.2.dist-info/METADATA
+++ b/divergent_beamsearch-0.1.4.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: divergent-beamsearch
-Version: 0.1.2
+Version: 0.1.4
 Summary: A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject.
 License-File: LICENCE
 Requires-Python: >=3.11
--- /dev/null
+++ b/divergent_beamsearch-0.1.4.dist-info/RECORD
@@ -0,0 +1,6 @@
+divergent_beamsearch/__init__.py,sha256=qrpVRoT3d-q1N9fJnzHI2X13e71LDY4-6eLOQ_gwCqQ,62
+divergent_beamsearch/algorithm.py,sha256=d0xU4OWiCEa5icdXZHoV1P-eKYftYMHhfBZMEVNkRXQ,8649
+divergent_beamsearch-0.1.4.dist-info/METADATA,sha256=f1nBA8_Q3a-PgtE6Z5YViFhWVRvuE3aLkPFQyin9274,2826
+divergent_beamsearch-0.1.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+divergent_beamsearch-0.1.4.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
+divergent_beamsearch-0.1.4.dist-info/RECORD,,
--- a/divergent_beamsearch-0.1.2.dist-info/RECORD
+++ /dev/null
@@ -1,6 +0,0 @@
-divergent_beamsearch/__init__.py,sha256=Z2R1pkj4EEHMKWVZX0upeE_Jtfb6joxgYHuRNxWc8Zo,43
-divergent_beamsearch/algorithm.py,sha256=w6aLDOnLwLabmHHOMCEx1Y8P8yaHkFEJxMGNw6f7RsU,6115
-divergent_beamsearch-0.1.2.dist-info/METADATA,sha256=fjqa8W8RpRSZOFV-3_D28o2353p54mB7bhEUrsRSoxw,2826
-divergent_beamsearch-0.1.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-divergent_beamsearch-0.1.2.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
-divergent_beamsearch-0.1.2.dist-info/RECORD,,