divergent-beamsearch 0.1.5__tar.gz → 0.1.7__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,10 +1,10 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: divergent-beamsearch
3
- Version: 0.1.5
3
+ Version: 0.1.7
4
4
  Summary: A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject.
5
5
  License-File: LICENCE
6
6
  Requires-Python: >=3.11
7
- Requires-Dist: multi-choices-parser>=0.9.57
7
+ Requires-Dist: multi-choices-parser>=0.9.61
8
8
  Requires-Dist: torch>=2.0.0
9
9
  Requires-Dist: transformers>=4.47.1
10
10
  Description-Content-Type: text/markdown
@@ -1,11 +1,11 @@
1
1
  [project]
2
2
  name = "divergent-beamsearch"
3
- version = "0.1.5"
3
+ version = "0.1.7"
4
4
  description = "A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject."
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.11"
7
7
  dependencies = [
8
- "multi-choices-parser>=0.9.57",
8
+ "multi-choices-parser>=0.9.61",
9
9
  "torch>=2.0.0",
10
10
  "transformers>=4.47.1",
11
11
  ]
@@ -1,7 +1,7 @@
1
1
  import math
2
2
  import torch
3
3
  from transformers import GPT2LMHeadModel
4
- from multi_choices_parser import MultiChoicesParser, end_symb
4
+ from multi_choices_parser import DEFAULT_END_SYMB
5
5
 
6
6
 
7
7
  class Parser:
@@ -19,10 +19,10 @@ def get_parsers_tokens(parsers : list[Parser], end_symb) -> tuple[list, list[int
19
19
  can_end = []
20
20
  for parser in parsers:
21
21
  tokens = list(parser.next())
22
- if end_symb in tokens:
23
- can_end.append(True)
22
+ try:
24
23
  tokens.remove(end_symb)
25
- else:
24
+ can_end.append(True)
25
+ except ValueError:
26
26
  can_end.append(False)
27
27
  parsers_tokens.append(tokens)
28
28
  return parsers_tokens, can_end
@@ -76,8 +76,14 @@ class AcceptEverythingParser(Parser):
76
76
  def copy(self):
77
77
  return self
78
78
 
79
def index_reduce_lists(x : torch.Tensor, indices : list[list[int]], reduce_func=torch.sum) -> torch.Tensor:
    """Reduce a per-row selection of entries of ``x``.

    For every position ``i`` in ``indices``, the entries ``x[i, indices[i]]``
    are gathered and collapsed along their last dimension with ``reduce_func``.
    Rows of ``x`` beyond ``len(indices)`` are ignored.

    Args:
        x: Tensor indexed as ``x[i, index]``; the first dimension is the row.
        indices: One list of column indices per row; the lists may have
            different lengths.
        reduce_func: Callable invoked as ``reduce_func(t, dim=-1)``
            (e.g. ``torch.sum`` or ``torch.logsumexp``).

    Returns:
        A tensor with one reduced value per entry of ``indices``, on the same
        device as ``x`` and cast to ``x.dtype``.
    """
    reduced = [reduce_func(x[row, cols], dim=-1) for row, cols in enumerate(indices)]
    if not reduced:
        # torch.stack rejects an empty list; return an explicit empty result instead.
        return torch.empty(0, dtype=x.dtype, device=x.device, requires_grad=x.requires_grad)
    # Stacking (rather than rebuilding with torch.tensor) keeps the values on
    # their device and connected to x's autograd graph, and also works when the
    # reduction leaves more than one element per row.
    return torch.stack(reduced).to(x.dtype)
84
+
79
85
  @torch.no_grad()
80
- def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int, max_length : int, parser : Parser, pad_token_id : int, batch_size=32, num_solutions = None, end_symb=end_symb) -> tuple[torch.Tensor, torch.Tensor]:
86
+ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int, max_length : int, parser : Parser, pad_token_id : int, batch_size=32, num_solutions = None, end_symb=DEFAULT_END_SYMB) -> tuple[torch.Tensor, torch.Tensor]:
81
87
  assert input_ids.shape[0] == 1, "Batch size must be 1"
82
88
  device = input_ids.device
83
89
  input_ids = input_ids.cpu()
@@ -120,7 +126,8 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
120
126
 
121
127
  scores_finished_current = scores_unfinished[can_end]
122
128
  solutions_finished_current = solutions_unfinished[can_end]
123
- scores_finished_current = scores_finished_current + log1mexp(logprobs[can_end, select_mask(parsers_tokens, can_end)].logsumexp(dim=-1)).squeeze(-1)
129
+ logprob_other_ans = index_reduce_lists(logprobs[can_end], select_mask(parsers_tokens, can_end), reduce_func=torch.logsumexp).squeeze(-1)
130
+ scores_finished_current = scores_finished_current + log1mexp(logprob_other_ans)
124
131
  scores_finished = torch.cat([scores_finished, scores_finished_current])
125
132
  if len(solutions_finished_current):
126
133
  pad = torch.full((len(scores_finished_current), solutions_finished_current.shape[1] - solutions_finished.shape[1]), pad_token_id, dtype=torch.long)
@@ -140,6 +147,7 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
140
147
  parser.step(token)
141
148
 
142
149
  # Special case of vanilla beam search where all answers are valid
150
+ # Warning: in this case the model will not stop at the end-of-sentence token
143
151
  if vanilla:
144
152
  order = scores_unfinished.argsort(descending=True)
145
153
  scores_finished = scores_unfinished[order][:num_solutions]
@@ -154,7 +162,9 @@ def set_slice_row(x : torch.Tensor, slices : torch.IntTensor, value) -> torch.Te
154
162
  x[i].index_fill_(0, indices[i], 0)
155
163
 
156
164
  @torch.no_grad()
157
- def divergent_logprob(input_ids : torch.Tensor, attention_mask : torch.Tensor | None, model : GPT2LMHeadModel, parsers : Parser | list[Parser] | None, batch_size=32, start : int | torch.IntTensor = None) -> torch.FloatTensor:
165
+ def divergent_logprob(input_ids : torch.Tensor, attention_mask : torch.Tensor | None, model : GPT2LMHeadModel,
166
+ parsers : Parser | list[Parser] | None, batch_size=32,
167
+ start : int | torch.IntTensor = None, end_symb=DEFAULT_END_SYMB) -> torch.FloatTensor:
158
168
  if start is None:
159
169
  start = 0
160
170
  if isinstance(start, int):
@@ -1,10 +1,13 @@
1
1
  import numpy as np
2
2
  import pytest
3
3
  import torch
4
- from transformers import GPT2LMHeadModel, GPT2Tokenizer
4
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
5
5
  from multi_choices_parser import MultiChoicesParser
6
6
  from divergent_beamsearch.algorithm import divergent_beamsearch, divergent_logprob, log1mexp
7
- from multi_choices_parser import MultiChoicesParser
7
+ from multi_choices_parser import MultiChoicesParser, DEFAULT_END_SYMB
8
+
9
+
10
+ TEST_END_SYMBS = [DEFAULT_END_SYMB, 'tokenizer']
8
11
 
9
12
  @pytest.fixture
10
13
  def model_and_tokenizer():
@@ -12,8 +15,29 @@ def model_and_tokenizer():
12
15
  tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
13
16
  return model, tokenizer
14
17
 
18
@pytest.fixture
def fakemodel_and_tokenizer():
    """Provide a tiny, randomly initialised GPT-2 and the pretrained GPT-2 tokenizer.

    The model is built from a hand-written configuration instead of pretrained
    weights, so only the tokenizer is loaded via ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple; the model has been switched to eval mode.
    """
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

    # Deliberately small architecture; the vocabulary size matches the tokenizer.
    tiny_settings = dict(
        vocab_size=gpt2_tokenizer.vocab_size,  # same vocabulary as the tokenizer
        n_positions=64,                        # maximum sequence length
        n_ctx=64,                              # context window size
        n_embd=8,                              # embedding width
        n_layer=1,                             # number of transformer layers
        n_head=2,                              # number of attention heads
    )
    tiny_config = GPT2Config(**tiny_settings)

    # Weights come from the default initialisation of the configured model.
    tiny_model = GPT2LMHeadModel(tiny_config)
    tiny_model.eval()

    return tiny_model, gpt2_tokenizer
37
+
15
38
  @pytest.mark.parametrize("device", ['cpu', 'cuda'])
16
- def test_divergent_beamsearch(model_and_tokenizer, device):
39
+ @pytest.mark.parametrize("end_symb", TEST_END_SYMBS)
40
+ def test_divergent_beamsearch(model_and_tokenizer, device, end_symb):
17
41
  if device == 'cuda' and not torch.cuda.is_available():
18
42
  pytest.skip("CUDA is not available on this machine.")
19
43
  model, tokenizer = model_and_tokenizer
@@ -24,13 +48,20 @@ def test_divergent_beamsearch(model_and_tokenizer, device):
24
48
  max_length = 10
25
49
  pad_token_id = tokenizer.eos_token_id
26
50
 
27
- possible_answers = [' Paris', ' Paris Hilton']
51
+ possible_answers = [' Paris', ' Madrid', ' Paris Hilton']
28
52
  tokenized_answers = tokenizer(possible_answers).input_ids
29
- multi_choices_parser = MultiChoicesParser([tokenized_answers])
30
53
 
31
- logprob_paris = model(input_ids).logits.cpu().log_softmax(dim=-1)[0, -1, tokenized_answers[0][0]]
32
- logprob_hilton = model(torch.cat([input_ids, torch.tensor(tokenized_answers[1][0], device=device).view(1,1)], dim=-1)).logits.cpu().log_softmax(dim=-1)[0, -1, tokenized_answers[1][1]]
33
- logprob_paris_hilton = logprob_paris + logprob_hilton
54
+ if end_symb == 'tokenizer':
55
+ end_symb = tokenizer.eos_token_id
56
+
57
+ multi_choices_parser = MultiChoicesParser([tokenized_answers], end_symb=end_symb)
58
+
59
+ with torch.no_grad():
60
+ logprob_paris = model(input_ids).logits.cpu().log_softmax(dim=-1)[0, -1, tokenized_answers[0][0]]
61
+ logprob_hilton = model(torch.cat([input_ids, torch.tensor(tokenized_answers[2][0], device=device).view(1,1)], dim=-1)).logits.cpu().log_softmax(dim=-1)[0, -1, tokenized_answers[2][1]]
62
+ logprob_paris_hilton = logprob_paris + logprob_hilton
63
+ logprob_madrid = model(input_ids).logits.cpu().log_softmax(dim=-1)[0, -1, tokenized_answers[1][0]]
64
+ logprob_paris_diverge = logprob_paris + log1mexp(logprob_hilton)
34
65
 
35
66
  scores, solutions = divergent_beamsearch(
36
67
  input_ids=input_ids,
@@ -39,18 +70,22 @@ def test_divergent_beamsearch(model_and_tokenizer, device):
39
70
  max_length=max_length,
40
71
  parser=multi_choices_parser,
41
72
  pad_token_id=pad_token_id,
42
- num_solutions=10
73
+ num_solutions=10,
74
+ end_symb=end_symb
43
75
  )
44
76
  true_solutions = torch.nn.utils.rnn.pad_sequence([torch.tensor(ans) for ans in tokenized_answers], batch_first=True, padding_value=pad_token_id)
45
77
  assert (solutions == true_solutions).all(), "Beam search did not return the expected solutions"
46
- assert scores[0] == logprob_paris + log1mexp(logprob_hilton), "Beam search did not return the expected score"
47
- assert scores[1] == logprob_paris_hilton, "Beam search did not return the expected score"
78
+ assert torch.isclose(scores[0], logprob_paris_diverge), "Beam search did not return the expected score"
79
+ assert torch.isclose(scores[1], logprob_madrid), "Beam search did not return the expected score"
80
+ assert torch.isclose(scores[2], logprob_paris_hilton), "Beam search did not return the expected score"
81
+
48
82
 
49
83
  @pytest.mark.parametrize("device", ['cpu', 'cuda'])
50
- def test_divergent_logprob(model_and_tokenizer, device):
84
+ @pytest.mark.parametrize("end_symb", TEST_END_SYMBS)
85
+ def test_divergent_logprob(fakemodel_and_tokenizer, device, end_symb):
51
86
  if device == 'cuda' and not torch.cuda.is_available():
52
87
  pytest.skip("CUDA is not available on this machine.")
53
- model, tokenizer = model_and_tokenizer
88
+ model, tokenizer = fakemodel_and_tokenizer
54
89
  model.to(device)
55
90
  prompts = [
56
91
  "The capital of France is Paris",
@@ -63,10 +98,14 @@ def test_divergent_logprob(model_and_tokenizer, device):
63
98
 
64
99
  possible_answers = [' Paris', ' Paris Hilton']
65
100
  tokenized_answers = tokenizer(possible_answers).input_ids
66
- multi_choices_parser = MultiChoicesParser([tokenized_answers])
101
+
102
+ if end_symb == 'tokenizer':
103
+ end_symb = tokenizer.eos_token_id
104
+
105
+ multi_choices_parser = MultiChoicesParser([tokenized_answers], end_symb=end_symb)
67
106
 
68
107
  input_len = attention_mask.sum(-1).cpu()
69
- probs = divergent_logprob(input_ids, attention_mask, model, multi_choices_parser, start=input_len - torch.tensor([1,2]))
108
+ probs = divergent_logprob(input_ids, attention_mask, model, multi_choices_parser, start=input_len - torch.tensor([1,2]), end_symb=end_symb)
70
109
 
71
110
  input_ids_1st = tokenizer("The capital of France is Paris Hilton", return_tensors='pt').input_ids.to(device)
72
111
  logprobs_1st = model(input_ids_1st).logits.cpu().log_softmax(dim=-1)
@@ -156,4 +195,4 @@ def test_vanilla_beamsearch(model_and_tokenizer, device):
156
195
  ]
157
196
  assert np.isclose(
158
197
  scores.cpu().numpy(), np.array([-8.1361, -8.7745, -9.1053]), atol=0.0001
159
- ).all()
198
+ ).all()
@@ -73,7 +73,7 @@ wheels = [
73
73
 
74
74
  [[package]]
75
75
  name = "divergent-beamsearch"
76
- version = "0.1.1"
76
+ version = "0.1.5"
77
77
  source = { editable = "." }
78
78
  dependencies = [
79
79
  { name = "multi-choices-parser" },
@@ -88,7 +88,7 @@ dev = [
88
88
 
89
89
  [package.metadata]
90
90
  requires-dist = [
91
- { name = "multi-choices-parser", specifier = ">=0.9.57" },
91
+ { name = "multi-choices-parser", specifier = ">=0.9.61" },
92
92
  { name = "torch", specifier = ">=2.0.0" },
93
93
  { name = "transformers", specifier = ">=4.47.1" },
94
94
  ]
@@ -221,11 +221,11 @@ wheels = [
221
221
 
222
222
  [[package]]
223
223
  name = "multi-choices-parser"
224
- version = "0.9.57"
224
+ version = "0.9.61"
225
225
  source = { registry = "https://pypi.org/simple" }
226
- sdist = { url = "https://files.pythonhosted.org/packages/69/55/e2228a3839d46282947f4383bc0a588751b164a24c162e4642a65ffe906f/multi_choices_parser-0.9.57.tar.gz", hash = "sha256:f4f42c4a6abbaa5a2529b976d6a4d756edb8fa422a59b98f30fd2a4331995600", size = 7662 }
226
+ sdist = { url = "https://files.pythonhosted.org/packages/e2/17/90a6125a2145c03e39c3e7f78f65121eb14dedb9de8b40aee3c8a24a709b/multi_choices_parser-0.9.61.tar.gz", hash = "sha256:be932cac4aeabe9ee057c6d4592ea4325a0a92e52758d77a9e08bafa2cd23294", size = 7889 }
227
227
  wheels = [
228
- { url = "https://files.pythonhosted.org/packages/4d/bf/e8d829acca04bc1429ca15582321c3b4702db0e071dbddcf3246f1895956/multi_choices_parser-0.9.57-py3-none-any.whl", hash = "sha256:ff58ac7c440d3129ffe89420039c4ddd6483e54b0526f68d25933e4f62d3c8d2", size = 6702 },
228
+ { url = "https://files.pythonhosted.org/packages/3c/4f/c5a514a510779202ff37505220edfba9154ceff31958ed71fa1878781af9/multi_choices_parser-0.9.61-py3-none-any.whl", hash = "sha256:36bc367bceb66bbfb1bea26d9a38aa9cd10273b54cef331dd7c69da582fb9c2a", size = 6870 },
229
229
  ]
230
230
 
231
231
  [[package]]
@@ -618,21 +618,21 @@ dependencies = [
618
618
  { name = "fsspec" },
619
619
  { name = "jinja2" },
620
620
  { name = "networkx" },
621
- { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
622
- { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
623
- { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
624
- { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
625
- { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
626
- { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
627
- { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
628
- { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
629
- { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
630
- { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
631
- { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
632
- { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
621
+ { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
622
+ { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
623
+ { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
624
+ { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
625
+ { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
626
+ { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
627
+ { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
628
+ { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
629
+ { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
630
+ { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
631
+ { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
632
+ { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
633
633
  { name = "setuptools", marker = "python_full_version >= '3.12'" },
634
634
  { name = "sympy" },
635
- { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux'" },
635
+ { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and platform_system == 'Linux'" },
636
636
  { name = "typing-extensions" },
637
637
  ]
638
638
  wheels = [
@@ -652,7 +652,7 @@ name = "tqdm"
652
652
  version = "4.67.1"
653
653
  source = { registry = "https://pypi.org/simple" }
654
654
  dependencies = [
655
- { name = "colorama", marker = "sys_platform == 'win32'" },
655
+ { name = "colorama", marker = "platform_system == 'Windows'" },
656
656
  ]
657
657
  sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 }
658
658
  wheels = [