divergent-beamsearch 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

--- divergent_beamsearch/__init__.py
+++ divergent_beamsearch/__init__.py
@@ -1 +1 @@
-from .algorithm import divergent_beamsearch
+from .algorithm import divergent_beamsearch, divergent_logprob
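The package now re-exports the new scoring function at top level, so both entry points can be imported directly:

    from divergent_beamsearch import divergent_beamsearch, divergent_logprob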
--- divergent_beamsearch/algorithm.py
+++ divergent_beamsearch/algorithm.py
@@ -35,10 +35,12 @@ def apply_mask_tokens(pred : torch.Tensor, parsers_tokens):
     return pred[~pred.isinf().all(dim=-1)]
 
 
-def batched_inference_logits(model : GPT2LMHeadModel, input_ids : torch.Tensor, batch_size : int = 32) -> torch.Tensor:
+def batched_inference_logits(model : GPT2LMHeadModel, input_ids : torch.Tensor, attention_mask : torch.Tensor | None = None, batch_size : int = 32) -> torch.Tensor:
     logits = []
+    if attention_mask is None:
+        attention_mask = torch.ones_like(input_ids)
     for i in range(0, input_ids.shape[0], batch_size):
-        logits.append(model(input_ids[i:i+batch_size]).logits)
+        logits.append(model(input_ids[i:i+batch_size], attention_mask=attention_mask[i:i+batch_size]).logits)
     return torch.cat(logits, dim=0)
 
 def select_mask(source : list, mask : list[bool]) -> list:
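The new attention_mask parameter lets callers batch padded sequences without pad tokens corrupting the logits; when it is omitted, the torch.ones_like fallback reproduces the old unpadded behaviour. A minimal usage sketch (not part of the diff; it assumes batched_inference_logits is imported from divergent_beamsearch.algorithm and uses standard Hugging Face GPT-2 classes):

    import torch
    from transformers import GPT2LMHeadModel, GPT2TokenizerFast
    from divergent_beamsearch.algorithm import batched_inference_logits

    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token
    model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

    # Right-padded batch: the mask tells the model which positions are real.
    batch = tokenizer(["The capital of France is", "Hi"], return_tensors="pt", padding=True)
    logits = batched_inference_logits(model, batch["input_ids"], batch["attention_mask"], batch_size=2)
    print(logits.shape)  # (2, padded_len, vocab_size)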
@@ -98,7 +100,7 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
     for _ in range(max_length):
         if len(input_ids_unfinished) == 0:
             break
-        pred = batched_inference_logits(model, input_ids_unfinished.to(device), batch_size)[:, -1].cpu()
+        pred = batched_inference_logits(model, input_ids_unfinished.to(device), batch_size=batch_size)[:, -1].cpu()
         parsers_tokens, can_end = get_parsers_tokens(parsers_unfinished)
         logprobs = torch.log_softmax(pred, dim=-1)
         logprobs_filtered = apply_mask_tokens(logprobs, parsers_tokens)
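This call-site change is forced by the signature change above: attention_mask now occupies the third positional slot, so the old positional call would bind the batch size to the mask. A sketch of the pitfall (hypothetical call, not from the diff):

    # Old call against the new signature: 32 is bound to attention_mask,
    # not batch_size, and the int fails as soon as it is sliced per batch.
    batched_inference_logits(model, input_ids, 32)
    # Correct: pass batch_size by keyword.
    batched_inference_logits(model, input_ids, batch_size=32)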
@@ -144,3 +146,60 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
     solutions_finished = solutions_unfinished[order][:num_solutions]
 
     return scores_finished, solutions_finished
+
+
+def set_slice_row(x : torch.Tensor, slices : torch.IntTensor, value) -> None:
+    indices = [torch.arange(start, end) for start, end in slices]
+    for i in range(slices.size(0)):
+        x[i].index_fill_(0, indices[i], value)
+
+@torch.no_grad()
+def divergent_logprob(input_ids : torch.Tensor, attention_mask : torch.Tensor | None, model : GPT2LMHeadModel, parsers : Parser | list[Parser] | None, batch_size=32, start : int | torch.IntTensor | None = None) -> torch.FloatTensor:
+    if start is None:
+        start = 0
+    if isinstance(start, int):
+        start = torch.tensor([start]*input_ids.shape[0])
+    assert start.shape[0] == input_ids.shape[0]
+    # -1 because of the next-token offset
+    start = start - 1
+
+    if attention_mask is None:
+        attention_mask = torch.ones_like(input_ids)
+
+    logits = batched_inference_logits(model, input_ids, attention_mask, batch_size).cpu()
+    input_ids = input_ids.cpu()
+    attention_mask = attention_mask.cpu()
+
+    logsoftmax = torch.log_softmax(logits, dim=-1)
+    log_probs = torch.gather(
+        logsoftmax[:, :-1, :], 2, input_ids[:, 1:, None]
+    ).squeeze(-1)
+    mask = attention_mask[:, 1:].cpu().clone()
+
+    input_len = attention_mask.sum(-1)
+    pos = torch.stack([torch.zeros_like(start), start], dim=-1)
+    pos_anti = pos.flip(1)
+    pos_anti[:, -1] = input_len
+    set_slice_row(mask, pos, 0)
+    vanilla_prob = (log_probs * mask).sum(-1)
+    if parsers is None:
+        parsers = AcceptEverythingParser(model.config.vocab_size)
+    if not isinstance(parsers, (tuple, list)):
+        parsers = [parsers.copy() for _ in range(len(input_ids))]
+    next_possible_tokens = []
+    for i, parser in enumerate(parsers):
+        # +1 because there is no next-token offset here
+        row_start = pos_anti[i,0]+1
+        for input_id, att in zip(input_ids[i, row_start:].tolist(), attention_mask[i, row_start:].tolist()):
+            if not att:
+                break
+            parser.step(input_id)
+        next_tokens = list(parser.next())
+        try:
+            next_tokens.remove(end_symb)
+        except ValueError:
+            pass
+        next_possible_tokens.append(next_tokens)
+    last_token_log_probs = torch.stack([log1mexp(logsoftmax[i, input_len[i]-1, tokens].logsumexp(-1)).squeeze() for i, tokens in enumerate(next_possible_tokens)])
+    prob = vanilla_prob + last_token_log_probs
+    return prob
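set_slice_row fills, in place and per row, the half-open index range [start, end) with the given value; divergent_logprob uses it to zero out the prompt positions of the mask so that only answer tokens contribute to vanilla_prob. A small demonstration with hypothetical values (assuming set_slice_row is imported from divergent_beamsearch.algorithm):

    import torch
    from divergent_beamsearch.algorithm import set_slice_row

    mask = torch.ones(2, 5)
    slices = torch.tensor([[0, 2], [0, 3]])
    set_slice_row(mask, slices, 0)
    # mask is now:
    # [[0., 0., 1., 1., 1.],
    #  [0., 0., 0., 1., 1.]]

And a sketch of calling divergent_logprob itself, reusing the `tokenizer`/`model` pair from the earlier sketch (those names are assumptions, not part of the diff); with parsers=None the function falls back to an AcceptEverythingParser over the model's vocabulary:

    from divergent_beamsearch import divergent_logprob

    prompt = "The capital of France is"
    answer = " Paris"
    ids = tokenizer(prompt + answer, return_tensors="pt")["input_ids"]
    n_prompt = tokenizer(prompt, return_tensors="pt")["input_ids"].shape[1]

    # Log-probability of generating the answer after the prompt and then
    # diverging, i.e. not continuing the parsed pattern.
    logp = divergent_logprob(ids, None, model, parsers=None, batch_size=32, start=n_prompt)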
--- divergent_beamsearch-0.1.2.dist-info/METADATA
+++ divergent_beamsearch-0.1.4.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: divergent-beamsearch
-Version: 0.1.2
+Version: 0.1.4
 Summary: A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject.
 License-File: LICENCE
 Requires-Python: >=3.11
--- /dev/null
+++ divergent_beamsearch-0.1.4.dist-info/RECORD
@@ -0,0 +1,6 @@
+divergent_beamsearch/__init__.py,sha256=qrpVRoT3d-q1N9fJnzHI2X13e71LDY4-6eLOQ_gwCqQ,62
+divergent_beamsearch/algorithm.py,sha256=d0xU4OWiCEa5icdXZHoV1P-eKYftYMHhfBZMEVNkRXQ,8649
+divergent_beamsearch-0.1.4.dist-info/METADATA,sha256=f1nBA8_Q3a-PgtE6Z5YViFhWVRvuE3aLkPFQyin9274,2826
+divergent_beamsearch-0.1.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+divergent_beamsearch-0.1.4.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
+divergent_beamsearch-0.1.4.dist-info/RECORD,,
--- divergent_beamsearch-0.1.2.dist-info/RECORD
+++ /dev/null
@@ -1,6 +0,0 @@
-divergent_beamsearch/__init__.py,sha256=Z2R1pkj4EEHMKWVZX0upeE_Jtfb6joxgYHuRNxWc8Zo,43
-divergent_beamsearch/algorithm.py,sha256=w6aLDOnLwLabmHHOMCEx1Y8P8yaHkFEJxMGNw6f7RsU,6115
-divergent_beamsearch-0.1.2.dist-info/METADATA,sha256=fjqa8W8RpRSZOFV-3_D28o2353p54mB7bhEUrsRSoxw,2826
-divergent_beamsearch-0.1.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-divergent_beamsearch-0.1.2.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
-divergent_beamsearch-0.1.2.dist-info/RECORD,,