divergent-beamsearch 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,112 +1,146 @@
1
- import math
2
- import torch
3
- from transformers import GPT2LMHeadModel
4
- from multi_choices_parser import MultiChoicesParser, end_symb
5
-
6
-
7
def get_parsers_tokens(parsers : list[MultiChoicesParser]) -> tuple[list[list], list[bool]]:
    """Collect, for every parser, the tokens it allows next and whether it may stop.

    Args:
        parsers: one parser per beam hypothesis.

    Returns:
        ``(parsers_tokens, can_end)``. ``parsers_tokens[i]`` is the list returned by
        ``parsers[i].next()`` with (the first occurrence of) ``end_symb`` removed.
        ``can_end[i]`` is ``True`` iff ``end_symb`` was among those tokens.
    """
    parsers_tokens = []
    can_end = []
    for parser in parsers:
        tokens = list(parser.next())
        ends = end_symb in tokens
        if ends:
            # end_symb is a marker, not a real vocabulary id: keep it out of the mask.
            tokens.remove(end_symb)
        can_end.append(ends)
        parsers_tokens.append(tokens)
    return parsers_tokens, can_end
19
-
20
def apply_mask_tokens(pred : torch.Tensor, parsers_tokens):
    """Mask out, row by row, every token the corresponding parser does not allow.

    Row ``i`` of ``pred`` is paired with ``parsers_tokens[i]``. Only the token ids
    listed there keep their value in row ``i``; every other entry of that row
    becomes ``-inf``.

    Args:
        pred: 2-D tensor of scores with one row per parser (e.g. log-probabilities
            over the vocabulary).
        parsers_tokens: sequence holding one collection of allowed token ids per row.

    Returns:
        A new tensor with the same shape as ``pred``; ``pred`` itself is left
        untouched. Rows with no allowed token are kept, entirely ``-inf``, so that
        row indices keep lining up with the caller's per-beam tensors.

    Raises:
        ValueError: if the number of token lists differs from the number of rows.
    """
    if len(parsers_tokens) != pred.shape[0]:
        raise ValueError(
            f"expected one token list per row ({pred.shape[0]}), got {len(parsers_tokens)}"
        )
    mask = torch.ones_like(pred, dtype=torch.bool)
    for row, tokens in enumerate(parsers_tokens):
        # Unmask only this row: tokens allowed for one beam must not leak to others.
        mask[row, torch.as_tensor(list(tokens), dtype=torch.long)] = False
    return pred.masked_fill(mask, -float('inf'))
26
-
27
-
28
def batched_inference_logits(model : GPT2LMHeadModel, input_ids : torch.Tensor, batch_size : int = 32) -> torch.Tensor:
    """Run ``model`` on ``input_ids`` at most ``batch_size`` rows at a time.

    The ``.logits`` of every chunk are concatenated along dim 0, so the result
    has one entry per row of ``input_ids``, in the original row order.
    """
    chunk_starts = range(0, input_ids.shape[0], batch_size)
    return torch.cat(
        [model(input_ids[start:start + batch_size]).logits for start in chunk_starts],
        dim=0,
    )
33
-
34
def select_mask(source : list, mask : list[bool]) -> list:
    """Return the items of ``source`` whose matching entry in ``mask`` is truthy.

    Raises:
        ValueError: if ``source`` and ``mask`` differ in length. An explicit check
            is used rather than ``assert``, which ``python -O`` strips; without the
            check, ``zip`` would silently truncate.
    """
    if len(source) != len(mask):
        raise ValueError(f"source has {len(source)} items but mask has {len(mask)}")
    return [item for item, keep in zip(source, mask) if keep]
37
-
38
-
39
def log1mexp(x: torch.Tensor) -> torch.Tensor:
    """Return ``log(1 - exp(x))`` evaluated accurately for ``x < 0``.

    Following Maechler (2012), "Accurately Computing log(1 - exp(-|a|))", two
    formulations are combined element-wise:

    * ``log(-expm1(x))`` where ``x > -log(2)`` (x close to 0);
    * ``log1p(-exp(x))`` elsewhere (x far below 0).

    Both branches are computed for every element and ``torch.where`` picks one.
    Inputs with ``x > 0`` are outside the domain and give NaN.
    """
    near_zero = x > -math.log(2)  # selects the expm1 formulation
    via_expm1 = torch.log(-torch.expm1(x))
    via_log1p = torch.log1p(-torch.exp(x))
    return torch.where(near_zero, via_expm1, via_log1p)
49
-
50
def _pad_to_width(solutions: torch.Tensor, width: int, pad_token_id: int) -> torch.Tensor:
    """Right-pad the 2-D long tensor ``solutions`` with ``pad_token_id`` to ``width`` columns."""
    missing = width - solutions.shape[1]
    if missing <= 0:
        return solutions
    pad = torch.full((solutions.shape[0], missing), pad_token_id, dtype=torch.long)
    return torch.cat([solutions, pad], dim=1)


def _end_logprobs(logprobs: torch.Tensor, parsers_tokens: list, can_end: list[bool]) -> torch.Tensor:
    """Log-probability that each beam able to end stops instead of continuing.

    For every row ``i`` with ``can_end[i]`` this is
    ``log(1 - sum(exp(logprobs[i, t]) for t in parsers_tokens[i]))``: the mass the
    model puts outside the continuations the parser still allows. A row with no
    allowed continuation gets 0 (certain stop). Must be called with at least one
    ``True`` in ``can_end``.
    """
    continue_logprobs = [
        logprobs[row, torch.as_tensor(list(tokens), dtype=torch.long)].logsumexp(dim=-1)
        for row, (tokens, row_can_end) in enumerate(zip(parsers_tokens, can_end))
        if row_can_end
    ]
    # Rounding can push logsumexp slightly above 0; clamp so log1mexp yields -inf, not NaN.
    return log1mexp(torch.stack(continue_logprobs).clamp(max=0.0))


@torch.no_grad()
def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int, max_length : int, multi_choices_parser : MultiChoicesParser, pad_token_id : int, batch_size=32, num_solutions = None) -> tuple[torch.Tensor, torch.Tensor]:
    """Beam search constrained by ``multi_choices_parser`` that collects finished answers.

    Each hypothesis may only emit the tokens its parser allows. Whenever a parser
    reports that it may end (``end_symb`` in ``next()``), the hypothesis is recorded
    as a finished answer. Its score is the cumulative log-probability plus
    ``log(1 - P(any allowed continuation))``.

    Args:
        input_ids: prompt with a leading batch dimension of 1 (asserted).
        model: causal LM whose forward returns an object with ``.logits``.
        beam_size: number of hypotheses kept per step.
        max_length: maximum number of generated tokens.
        multi_choices_parser: constraint on the generated tokens.
        pad_token_id: id used to right-pad answers shorter than the longest one.
        batch_size: number of hypotheses passed to the model at once.
        num_solutions: number of finished answers to keep; defaults to ``beam_size``.

    Returns:
        ``(scores, solutions)``: at most ``num_solutions`` scores in descending order
        and the matching generated token ids (prompt excluded), right-padded with
        ``pad_token_id``.
    """
    assert input_ids.shape[0] == 1, "Batch size must be 1"
    # Bookkeeping happens on CPU; only model calls use the caller's device.
    device = input_ids.device
    input_ids = input_ids.cpu()

    if num_solutions is None:
        num_solutions = beam_size

    parsers_unfinished = [multi_choices_parser]
    scores_finished = torch.tensor([], dtype=torch.float)
    solutions_finished = torch.tensor([], dtype=torch.long).view(0, 0)

    input_ids_unfinished = input_ids
    scores_unfinished = torch.tensor([0.0], dtype=torch.float)
    solutions_unfinished = torch.tensor([], dtype=torch.long).view(1, 0)

    for _ in range(max_length):
        if len(input_ids_unfinished) == 0:
            break
        pred = batched_inference_logits(model, input_ids_unfinished.to(device), batch_size)[:, -1].cpu()
        parsers_tokens, can_end = get_parsers_tokens(parsers_unfinished)
        logprobs = torch.log_softmax(pred, dim=-1)

        # Record beams whose parser accepts ending here, before any masking.
        if any(can_end):
            ends_mask = torch.tensor(can_end, dtype=torch.bool)
            scores_now = scores_unfinished[ends_mask] + _end_logprobs(logprobs, parsers_tokens, can_end)
            solutions_now = solutions_unfinished[ends_mask]
            # Earlier answers may be shorter: pad both sides to a common width.
            width = max(solutions_finished.shape[1], solutions_now.shape[1])
            scores_finished = torch.cat([scores_finished, scores_now])
            solutions_finished = torch.cat(
                [
                    _pad_to_width(solutions_finished, width, pad_token_id),
                    _pad_to_width(solutions_now, width, pad_token_id),
                ],
                dim=0,
            )
            # Keep only the num_solutions best finished answers.
            order = scores_finished.argsort(descending=True)[:num_solutions]
            scores_finished = scores_finished[order]
            solutions_finished = solutions_finished[order]

        logprobs_filtered = apply_mask_tokens(logprobs, parsers_tokens)
        if len(logprobs_filtered):
            topk = torch.topk(logprobs_filtered, beam_size, dim=-1)
            # Rank candidate extensions by cumulative score, not just the new token.
            candidate_scores = topk.values + scores_unfinished.unsqueeze(-1)
            topk_global = candidate_scores.flatten().topk(beam_size)
            best_tokens_row = topk_global.indices // beam_size
            best_tokens_col = topk_global.indices % beam_size
            best_tokens = topk.indices[best_tokens_row, best_tokens_col]
            best_tokens_logprobs = topk.values[best_tokens_row, best_tokens_col]
            # -inf entries are tokens the parser forbids: drop them.
            notinf = ~best_tokens_logprobs.isinf()
            best_tokens = best_tokens[notinf]
            best_tokens_row = best_tokens_row[notinf]
            best_tokens_logprobs = best_tokens_logprobs[notinf]
        else:
            best_tokens = torch.tensor([], dtype=torch.long)
            best_tokens_row = torch.tensor([], dtype=torch.long)
            best_tokens_logprobs = torch.tensor([], dtype=torch.float)

        input_ids_unfinished = torch.cat([input_ids_unfinished[best_tokens_row], best_tokens.unsqueeze(-1)], dim=-1)
        scores_unfinished = scores_unfinished[best_tokens_row] + best_tokens_logprobs
        solutions_unfinished = torch.cat([solutions_unfinished[best_tokens_row], best_tokens.unsqueeze(-1)], dim=-1)
        parsers_unfinished = [parsers_unfinished[row].copy() for row in best_tokens_row]
        for parser, token in zip(parsers_unfinished, best_tokens.tolist()):
            parser.step(token)

    return scores_finished, solutions_finished
109
-
110
-
111
-
112
-
1
+ import math
2
+ import torch
3
+ from transformers import GPT2LMHeadModel
4
+ from multi_choices_parser import MultiChoicesParser, end_symb
5
+
6
+
7
class Parser:
    """Interface for the token constraints used by ``divergent_beamsearch``.

    Subclasses override all three methods; the base versions only raise
    ``NotImplementedError``.
    """

    def step(self, token):
        """Advance the parser state by consuming ``token``."""
        raise NotImplementedError()

    def next(self):
        """Return the tokens allowed next.

        ``get_parsers_tokens`` treats an ``end_symb`` entry in this collection as
        "the sequence may stop here".
        """
        raise NotImplementedError()

    def copy(self):
        """Return the parser to use for a new beam branch.

        ``divergent_beamsearch`` calls ``copy()`` for each surviving beam and then
        ``step``s the copy. Stateful parsers presumably need a real copy here.
        """
        raise NotImplementedError()
16
+
17
def get_parsers_tokens(parsers : list[Parser]) -> tuple[list[list], list[bool]]:
    """Collect, for every parser, the tokens it allows next and whether it may stop.

    Args:
        parsers: one parser per beam hypothesis.

    Returns:
        ``(parsers_tokens, can_end)``. ``parsers_tokens[i]`` is the list returned by
        ``parsers[i].next()`` with (the first occurrence of) ``end_symb`` removed.
        ``can_end[i]`` is ``True`` iff ``end_symb`` was among those tokens.
    """
    parsers_tokens = []
    can_end = []
    for parser in parsers:
        tokens = list(parser.next())
        ends = end_symb in tokens
        if ends:
            # end_symb is a marker, not a real vocabulary id: keep it out of the mask.
            tokens.remove(end_symb)
        can_end.append(ends)
        parsers_tokens.append(tokens)
    return parsers_tokens, can_end
29
+
30
def apply_mask_tokens(pred : torch.Tensor, parsers_tokens):
    """Mask out, row by row, every token the corresponding parser does not allow.

    Row ``i`` of ``pred`` is paired with ``parsers_tokens[i]``. Only the token ids
    listed there keep their value in row ``i``; every other entry of that row
    becomes ``-inf``.

    Args:
        pred: 2-D tensor of scores with one row per parser (e.g. log-probabilities
            over the vocabulary).
        parsers_tokens: sequence holding one collection of allowed token ids per row.

    Returns:
        A new tensor with the same shape as ``pred``; ``pred`` itself is left
        untouched. Rows with no allowed token are kept, entirely ``-inf``, so that
        row indices keep lining up with the caller's per-beam tensors.

    Raises:
        ValueError: if the number of token lists differs from the number of rows.
    """
    if len(parsers_tokens) != pred.shape[0]:
        raise ValueError(
            f"expected one token list per row ({pred.shape[0]}), got {len(parsers_tokens)}"
        )
    mask = torch.ones_like(pred, dtype=torch.bool)
    for row, tokens in enumerate(parsers_tokens):
        # Unmask only this row: tokens allowed for one beam must not leak to others.
        mask[row, torch.as_tensor(list(tokens), dtype=torch.long)] = False
    return pred.masked_fill(mask, -float('inf'))
36
+
37
+
38
def batched_inference_logits(model : GPT2LMHeadModel, input_ids : torch.Tensor, batch_size : int = 32) -> torch.Tensor:
    """Run ``model`` on ``input_ids`` at most ``batch_size`` rows at a time.

    The ``.logits`` of every chunk are concatenated along dim 0, so the result
    has one entry per row of ``input_ids``, in the original row order.
    """
    chunk_starts = range(0, input_ids.shape[0], batch_size)
    return torch.cat(
        [model(input_ids[start:start + batch_size]).logits for start in chunk_starts],
        dim=0,
    )
43
+
44
def select_mask(source : list, mask : list[bool]) -> list:
    """Return the items of ``source`` whose matching entry in ``mask`` is truthy.

    Raises:
        ValueError: if ``source`` and ``mask`` differ in length. An explicit check
            is used rather than ``assert``, which ``python -O`` strips; without the
            check, ``zip`` would silently truncate.
    """
    if len(source) != len(mask):
        raise ValueError(f"source has {len(source)} items but mask has {len(mask)}")
    return [item for item, keep in zip(source, mask) if keep]
47
+
48
+
49
def log1mexp(x: torch.Tensor) -> torch.Tensor:
    """Return ``log(1 - exp(x))`` evaluated accurately for ``x < 0``.

    Following Maechler (2012), "Accurately Computing log(1 - exp(-|a|))", two
    formulations are combined element-wise:

    * ``log(-expm1(x))`` where ``x > -log(2)`` (x close to 0);
    * ``log1p(-exp(x))`` elsewhere (x far below 0).

    Both branches are computed for every element and ``torch.where`` picks one.
    Inputs with ``x > 0`` are outside the domain and give NaN.
    """
    near_zero = x > -math.log(2)  # selects the expm1 formulation
    via_expm1 = torch.log(-torch.expm1(x))
    via_log1p = torch.log1p(-torch.exp(x))
    return torch.where(near_zero, via_expm1, via_log1p)
59
+
60
+
61
+
62
+
63
class AcceptEverythingParser(Parser):
    """Unconstrained parser: every id in ``range(vocab_size)`` is always allowed.

    It keeps no state, so ``step`` is a no-op and ``copy`` hands back the same
    instance. ``divergent_beamsearch`` substitutes it when ``parser`` is ``None``.
    """

    def __init__(self, vocab_size : int):
        self.vocab_size = vocab_size
        # Built once; next() returns this same tuple on every call.
        self.tokens = tuple(range(vocab_size))

    def step(self, token):
        # Nothing to track: the allowed set never changes.
        return None

    def next(self):
        return self.tokens

    def copy(self):
        # Stateless, so beams can share a single instance.
        return self
76
+
77
def _pad_to_width(solutions: torch.Tensor, width: int, pad_token_id: int) -> torch.Tensor:
    """Right-pad the 2-D long tensor ``solutions`` with ``pad_token_id`` to ``width`` columns."""
    missing = width - solutions.shape[1]
    if missing <= 0:
        return solutions
    pad = torch.full((solutions.shape[0], missing), pad_token_id, dtype=torch.long)
    return torch.cat([solutions, pad], dim=1)


def _end_logprobs(logprobs: torch.Tensor, parsers_tokens: list, can_end: list[bool]) -> torch.Tensor:
    """Log-probability that each beam able to end stops instead of continuing.

    For every row ``i`` with ``can_end[i]`` this is
    ``log(1 - sum(exp(logprobs[i, t]) for t in parsers_tokens[i]))``: the mass the
    model puts outside the continuations the parser still allows. A row with no
    allowed continuation gets 0 (certain stop). Must be called with at least one
    ``True`` in ``can_end``.
    """
    continue_logprobs = [
        logprobs[row, torch.as_tensor(list(tokens), dtype=torch.long)].logsumexp(dim=-1)
        for row, (tokens, row_can_end) in enumerate(zip(parsers_tokens, can_end))
        if row_can_end
    ]
    # Rounding can push logsumexp slightly above 0; clamp so log1mexp yields -inf, not NaN.
    return log1mexp(torch.stack(continue_logprobs).clamp(max=0.0))


@torch.no_grad()
def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int, max_length : int, parser : Parser | None, pad_token_id : int, batch_size=32, num_solutions = None) -> tuple[torch.Tensor, torch.Tensor]:
    """Beam search constrained by ``parser`` that collects finished answers.

    Each hypothesis may only emit the tokens its parser allows. Whenever a parser
    reports that it may end (``end_symb`` in ``next()``), the hypothesis is recorded
    as a finished answer. Its score is the cumulative log-probability plus
    ``log(1 - P(any allowed continuation))``.

    Args:
        input_ids: prompt with a leading batch dimension of 1 (asserted).
        model: causal LM whose forward returns an object with ``.logits``. When
            ``parser`` is None, ``model.config.vocab_size`` is also read.
        beam_size: number of hypotheses kept per step.
        max_length: maximum number of generated tokens.
        parser: constraint on the generated tokens. ``None`` runs vanilla beam
            search: every token is allowed and the surviving beams are returned.
        pad_token_id: id used to right-pad answers shorter than the longest one.
        batch_size: number of hypotheses passed to the model at once.
        num_solutions: number of answers to return; defaults to ``beam_size``.

    Returns:
        ``(scores, solutions)``: at most ``num_solutions`` scores in descending order
        and the matching generated token ids (prompt excluded), right-padded with
        ``pad_token_id``.
    """
    assert input_ids.shape[0] == 1, "Batch size must be 1"
    # Bookkeeping happens on CPU; only model calls use the caller's device.
    device = input_ids.device
    input_ids = input_ids.cpu()

    if num_solutions is None:
        num_solutions = beam_size
    vanilla = parser is None
    if vanilla:
        parser = AcceptEverythingParser(model.config.vocab_size)

    parsers_unfinished = [parser]
    scores_finished = torch.tensor([], dtype=torch.float)
    solutions_finished = torch.tensor([], dtype=torch.long).view(0, 0)

    input_ids_unfinished = input_ids
    scores_unfinished = torch.tensor([0.0], dtype=torch.float)
    solutions_unfinished = torch.tensor([], dtype=torch.long).view(1, 0)

    for _ in range(max_length):
        if len(input_ids_unfinished) == 0:
            break
        pred = batched_inference_logits(model, input_ids_unfinished.to(device), batch_size)[:, -1].cpu()
        parsers_tokens, can_end = get_parsers_tokens(parsers_unfinished)
        logprobs = torch.log_softmax(pred, dim=-1)

        # Record beams whose parser accepts ending here, before any masking.
        if any(can_end):
            ends_mask = torch.tensor(can_end, dtype=torch.bool)
            scores_now = scores_unfinished[ends_mask] + _end_logprobs(logprobs, parsers_tokens, can_end)
            solutions_now = solutions_unfinished[ends_mask]
            # Earlier answers may be shorter: pad both sides to a common width.
            width = max(solutions_finished.shape[1], solutions_now.shape[1])
            scores_finished = torch.cat([scores_finished, scores_now])
            solutions_finished = torch.cat(
                [
                    _pad_to_width(solutions_finished, width, pad_token_id),
                    _pad_to_width(solutions_now, width, pad_token_id),
                ],
                dim=0,
            )
            # Keep only the num_solutions best finished answers.
            order = scores_finished.argsort(descending=True)[:num_solutions]
            scores_finished = scores_finished[order]
            solutions_finished = solutions_finished[order]

        logprobs_filtered = apply_mask_tokens(logprobs, parsers_tokens)
        if len(logprobs_filtered):
            topk = torch.topk(logprobs_filtered, beam_size, dim=-1)
            # Rank candidate extensions by cumulative score, not just the new token.
            candidate_scores = topk.values + scores_unfinished.unsqueeze(-1)
            topk_global = candidate_scores.flatten().topk(beam_size)
            best_tokens_row = topk_global.indices // beam_size
            best_tokens_col = topk_global.indices % beam_size
            best_tokens = topk.indices[best_tokens_row, best_tokens_col]
            best_tokens_logprobs = topk.values[best_tokens_row, best_tokens_col]
            # -inf entries are tokens the parser forbids: drop them.
            notinf = ~best_tokens_logprobs.isinf()
            best_tokens = best_tokens[notinf]
            best_tokens_row = best_tokens_row[notinf]
            best_tokens_logprobs = best_tokens_logprobs[notinf]
        else:
            best_tokens = torch.tensor([], dtype=torch.long)
            best_tokens_row = torch.tensor([], dtype=torch.long)
            best_tokens_logprobs = torch.tensor([], dtype=torch.float)

        input_ids_unfinished = torch.cat([input_ids_unfinished[best_tokens_row], best_tokens.unsqueeze(-1)], dim=-1)
        scores_unfinished = scores_unfinished[best_tokens_row] + best_tokens_logprobs
        solutions_unfinished = torch.cat([solutions_unfinished[best_tokens_row], best_tokens.unsqueeze(-1)], dim=-1)
        parsers_unfinished = [parsers_unfinished[row].copy() for row in best_tokens_row]
        for beam_parser, token in zip(parsers_unfinished, best_tokens.tolist()):
            beam_parser.step(token)

    # Vanilla beam search: no parser ever signals an end, so return the surviving beams.
    if vanilla:
        order = scores_unfinished.argsort(descending=True)
        scores_finished = scores_unfinished[order][:num_solutions]
        solutions_finished = solutions_unfinished[order][:num_solutions]

    return scores_finished, solutions_finished
@@ -1,11 +1,11 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: divergent-beamsearch
3
- Version: 0.1.0
3
+ Version: 0.1.2
4
4
  Summary: A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject.
5
5
  License-File: LICENCE
6
6
  Requires-Python: >=3.11
7
7
  Requires-Dist: multi-choices-parser>=0.9.57
8
- Requires-Dist: torch>=2.5.1
8
+ Requires-Dist: torch>=2.0.0
9
9
  Requires-Dist: transformers>=4.47.1
10
10
  Description-Content-Type: text/markdown
11
11
 
@@ -0,0 +1,6 @@
1
+ divergent_beamsearch/__init__.py,sha256=Z2R1pkj4EEHMKWVZX0upeE_Jtfb6joxgYHuRNxWc8Zo,43
2
+ divergent_beamsearch/algorithm.py,sha256=w6aLDOnLwLabmHHOMCEx1Y8P8yaHkFEJxMGNw6f7RsU,6115
3
+ divergent_beamsearch-0.1.2.dist-info/METADATA,sha256=fjqa8W8RpRSZOFV-3_D28o2353p54mB7bhEUrsRSoxw,2826
4
+ divergent_beamsearch-0.1.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
+ divergent_beamsearch-0.1.2.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
6
+ divergent_beamsearch-0.1.2.dist-info/RECORD,,
@@ -1,21 +1,21 @@
1
- MIT License
2
-
3
- Copyright (c) 2025 Hichem Ammar Khodja
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Hichem Ammar Khodja
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
21
  SOFTWARE.
@@ -1,6 +0,0 @@
1
- divergent_beamsearch/__init__.py,sha256=Z2R1pkj4EEHMKWVZX0upeE_Jtfb6joxgYHuRNxWc8Zo,43
2
- divergent_beamsearch/algorithm.py,sha256=6cWp6XHepSn1rjqQFkASxd8k3OEUarKNAiqPKfrA78k,5324
3
- divergent_beamsearch-0.1.0.dist-info/METADATA,sha256=UCkp3rgFZ89kmwFVy_N_dEy45NTGN_yFhf-J6WCCR4U,2826
4
- divergent_beamsearch-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
- divergent_beamsearch-0.1.0.dist-info/licenses/LICENCE,sha256=jDQOOFKJxgrQwcEyipwKcKzj5IX_paD_41c3iOjH3qw,1095
6
- divergent_beamsearch-0.1.0.dist-info/RECORD,,