divergent-beamsearch 0.1.2__tar.gz → 0.1.3__tar.gz

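Summary of changes: `batched_inference_logits` gains an optional `attention_mask` argument (defaulting to a mask of ones, which preserves the old behaviour); a new `divergent_logprob` function scores answers already present in a padded batch; and a new `test_divergent_logprob` exercises it against manually computed log-probabilities.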
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: divergent-beamsearch
-Version: 0.1.2
+Version: 0.1.3
 Summary: A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject.
 License-File: LICENCE
 Requires-Python: >=3.11
@@ -1,6 +1,6 @@
 [project]
 name = "divergent-beamsearch"
-version = "0.1.2"
+version = "0.1.3"
 description = "A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject."
 readme = "README.md"
 requires-python = ">=3.11"
@@ -35,10 +35,12 @@ def apply_mask_tokens(pred : torch.Tensor, parsers_tokens):
     return pred[~pred.isinf().all(dim=-1)]
 
 
-def batched_inference_logits(model : GPT2LMHeadModel, input_ids : torch.Tensor, batch_size : int = 32) -> torch.Tensor:
+def batched_inference_logits(model : GPT2LMHeadModel, input_ids : torch.Tensor, attention_mask : torch.Tensor | None = None, batch_size : int = 32) -> torch.Tensor:
     logits = []
+    if attention_mask is None:
+        attention_mask = torch.ones_like(input_ids)
     for i in range(0, input_ids.shape[0], batch_size):
-        logits.append(model(input_ids[i:i+batch_size]).logits)
+        logits.append(model(input_ids[i:i+batch_size], attention_mask=attention_mask[i:i+batch_size]).logits)
     return torch.cat(logits, dim=0)
 
 def select_mask(source : list, mask : list[bool]) -> list:
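The new `attention_mask` parameter makes batched inference correct for right-padded batches; when it is omitted, an all-ones mask reproduces the previous behaviour. A minimal usage sketch (the import path follows the test file's `divergent_beamsearch.algorithm`; the "gpt2" checkpoint is an assumption for illustration):

from transformers import GPT2LMHeadModel, GPT2Tokenizer
from divergent_beamsearch.algorithm import batched_inference_logits

model = GPT2LMHeadModel.from_pretrained("gpt2")  # assumed checkpoint, for illustration only
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

# Right-padded batch; the attention mask distinguishes real tokens from padding.
inp = tokenizer(["The capital of France is", "Hi"], return_tensors="pt", padding=True)
logits = batched_inference_logits(model, inp.input_ids, inp.attention_mask, batch_size=32)
print(logits.shape)  # torch.Size([2, seq_len, vocab_size])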
@@ -98,7 +100,7 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
     for _ in range(max_length):
         if len(input_ids_unfinished) == 0:
             break
-        pred = batched_inference_logits(model, input_ids_unfinished.to(device), batch_size)[:, -1].cpu()
+        pred = batched_inference_logits(model, input_ids_unfinished.to(device), batch_size=batch_size)[:, -1].cpu()
         parsers_tokens, can_end = get_parsers_tokens(parsers_unfinished)
         logprobs = torch.log_softmax(pred, dim=-1)
         logprobs_filtered = apply_mask_tokens(logprobs, parsers_tokens)
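Note: with `attention_mask` now occupying the third positional slot, the old positional call would silently bind `batch_size` to `attention_mask`, hence the switch to a keyword argument.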
@@ -144,3 +146,60 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
     solutions_finished = solutions_unfinished[order][:num_solutions]
 
     return scores_finished, solutions_finished
+
+
+def set_slice_row(x : torch.Tensor, slices : torch.IntTensor, value) -> None:
+    indices = [torch.arange(start, end) for start, end in slices]
+    for i in range(slices.size(0)):
+        x[i].index_fill_(0, indices[i], value)
+
+@torch.no_grad()
+def divergent_logprob(input_ids : torch.Tensor, attention_mask : torch.Tensor | None, model : GPT2LMHeadModel, parsers : Parser | list[Parser] | None, batch_size=32, start : int | torch.IntTensor | None = None) -> torch.FloatTensor:
+    if start is None:
+        start = 0
+    if isinstance(start, int):
+        start = torch.tensor([start]*input_ids.shape[0])
+    assert start.shape[0] == input_ids.shape[0]
+    # -1 because next token offset
+    start = start - 1
+
+    if attention_mask is None:
+        attention_mask = torch.ones_like(input_ids)
+
+    logits = batched_inference_logits(model, input_ids, attention_mask, batch_size).cpu()
+    input_ids = input_ids.cpu()
+    attention_mask = attention_mask.cpu()
+
+    logsoftmax = torch.log_softmax(logits, dim=-1)
+    log_probs = torch.gather(
+        logsoftmax[:, :-1, :], 2, input_ids[:, 1:, None]
+    ).squeeze(-1)
+    mask = attention_mask[:, 1:].cpu().clone()
+
+    input_len = attention_mask.sum(-1)
+    pos = torch.stack([torch.zeros_like(start), start], dim=-1)
+    pos_anti = pos.flip(1)
+    pos_anti[:, -1] = input_len
+    set_slice_row(mask, pos, 0)
+    vanilla_prob = (log_probs * mask).sum(-1)
+    if parsers is None:
+        parsers = AcceptEverythingParser(model.config.vocab_size)
+    if not isinstance(parsers, (tuple, list)):
+        parsers = [parsers.copy() for _ in range(len(input_ids))]
+    next_possible_tokens = []
+    for i, parser in enumerate(parsers):
+        # +1 because no next-token offset
+        start = pos_anti[i,0]+1
+        for input_id, att in zip(input_ids[i, start:].tolist(), attention_mask[i, start:].tolist()):
+            if not att:
+                break
+            parser.step(input_id)
+        next_tokens = list(parser.next())
+        try:
+            next_tokens.remove(end_symb)
+        except ValueError:
+            pass
+        next_possible_tokens.append(next_tokens)
+    last_token_log_probs = torch.stack([log1mexp(logsoftmax[i, input_len[i]-1, tokens].logsumexp(-1)).squeeze() for i, tokens in enumerate(next_possible_tokens)])
+    prob = vanilla_prob + last_token_log_probs
+    return prob
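In words, `divergent_logprob` scores an answer already present in `input_ids` (from index `start` onward, with padding handled through `attention_mask`): `vanilla_prob` sums the log-probabilities of the answer tokens, and the `log1mexp` term adds the log-probability that the model does not continue with any token the parser still accepts, i.e.

    log P_divergent(answer | prompt) = log P(answer | prompt) + log(1 - Σ P(t | prompt, answer)),

where the sum runs over the parser's remaining continuations t, with the end symbol excluded.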
@@ -3,7 +3,7 @@ import pytest
 import torch
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
 from multi_choices_parser import MultiChoicesParser
-from divergent_beamsearch.algorithm import divergent_beamsearch, log1mexp
+from divergent_beamsearch.algorithm import divergent_beamsearch, divergent_logprob, log1mexp
 from multi_choices_parser import MultiChoicesParser
 
 @pytest.fixture
@@ -46,6 +46,40 @@ def test_divergent_beamsearch(model_and_tokenizer, device):
     assert scores[0] == logprob_paris + log1mexp(logprob_hilton), "Beam search did not return the expected score"
     assert scores[1] == logprob_paris_hilton, "Beam search did not return the expected score"
 
+@pytest.mark.parametrize("device", ['cpu', 'cuda'])
+def test_divergent_logprob(model_and_tokenizer, device):
+    if device == 'cuda' and not torch.cuda.is_available():
+        pytest.skip("CUDA is not available on this machine.")
+    model, tokenizer = model_and_tokenizer
+    model.to(device)
+    prompts = [
+        "The capital of France is Paris",
+        "The top model Paris Hilton"
+    ]
+    tokenizer.pad_token = tokenizer.eos_token
+    inp = tokenizer(prompts, return_tensors="pt", padding=True)
+    input_ids = inp.input_ids.to(device)
+    attention_mask = inp.attention_mask.to(device)
+
+    possible_answers = [' Paris', ' Paris Hilton']
+    tokenized_answers = tokenizer(possible_answers).input_ids
+    multi_choices_parser = MultiChoicesParser([tokenized_answers])
+
+    input_len = attention_mask.sum(-1).cpu()
+    probs = divergent_logprob(input_ids, attention_mask, model, multi_choices_parser, start=input_len - torch.tensor([1,2]))
+
+    input_ids_1st = tokenizer("The capital of France is Paris Hilton", return_tensors='pt').input_ids.to(device)
+    logprobs_1st = model(input_ids_1st).logits.cpu().log_softmax(dim=-1)
+    logprob_paris = logprobs_1st[0, input_ids_1st.shape[1]-3, tokenized_answers[1][0]] # P(Paris | The capital of France is)
+    logprob_hilton = logprobs_1st[0, input_ids_1st.shape[1]-2, tokenized_answers[1][1]] # P(Hilton | The capital of France is Paris)
+
+    input_ids_2nd = tokenizer("The top model Paris Hilton", return_tensors='pt').input_ids.to(device)
+    logprobs_2nd = model(input_ids_2nd).logits.cpu().log_softmax(dim=-1)
+    logprob_paris_hilton = logprobs_2nd[0, -3, tokenized_answers[1][0]] + logprobs_2nd[0, -2, tokenized_answers[1][1]] # P(Paris Hilton | The top model)
+
+    assert torch.isclose(probs[0], logprob_paris + log1mexp(logprob_hilton)), "P_divergent(Paris | The capital of France is) is incorrect"
+    assert torch.isclose(probs[1], logprob_paris_hilton), "P_divergent(Paris Hilton | The top model) is incorrect"
+
 @pytest.mark.parametrize("device", ['cpu', 'cuda'])
 def test_vanilla_beamsearch(model_and_tokenizer, device):
     if device == 'cuda' and not torch.cuda.is_available():
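The expected values in `test_divergent_logprob` follow from the formula above: for the first prompt the parser can still continue ' Paris' with 'Hilton', so the score is `logprob_paris + log1mexp(logprob_hilton)`; for the second, ' Paris Hilton' is a complete answer with no remaining continuation, so the divergence term vanishes and the score is the plain sequence log-probability.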