divergent-beamsearch 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- divergent_beamsearch/algorithm.py +21 -3
- {divergent_beamsearch-0.1.6.dist-info → divergent_beamsearch-0.1.8.dist-info}/METADATA +1 -1
- divergent_beamsearch-0.1.8.dist-info/RECORD +6 -0
- divergent_beamsearch-0.1.6.dist-info/RECORD +0 -6
- {divergent_beamsearch-0.1.6.dist-info → divergent_beamsearch-0.1.8.dist-info}/WHEEL +0 -0
- {divergent_beamsearch-0.1.6.dist-info → divergent_beamsearch-0.1.8.dist-info}/licenses/LICENCE +0 -0
@@ -76,6 +76,20 @@ class AcceptEverythingParser(Parser):
|
|
76
76
|
def copy(self):
|
77
77
|
return self
|
78
78
|
|
79
|
+
def index_reduce_lists(x : torch.Tensor, indices : list[list[int]], reduce_func=torch.sum) -> torch.Tensor:
    """Reduce a per-row subset of columns of ``x``.

    For each ``i``, selects ``x[i, indices[i]]`` and reduces it over the last
    dimension with ``reduce_func`` (e.g. ``torch.sum`` or ``torch.logsumexp``).

    Args:
        x: Tensor indexed as ``x[i, index]``. Assumed 2-D ``(rows, columns)`` so that
            each reduction yields a scalar -- TODO confirm for other ranks.
        indices: One list of column indices per row; the first ``len(indices)`` rows
            of ``x`` are used.
        reduce_func: Callable taking a tensor and a ``dim`` keyword argument.

    Returns:
        Tensor of length ``len(indices)`` (for 2-D ``x``) with ``x``'s dtype and
        device. The result stays connected to ``x``'s autograd graph.
    """
    if not indices:
        # torch.stack cannot stack an empty list.
        return torch.empty(0, dtype=x.dtype, device=x.device)
    values = [reduce_func(x[i, index], dim=-1) for i, index in enumerate(indices)]
    # torch.stack keeps autograd history. torch.tensor(list_of_tensors) would instead
    # copy the values into a fresh leaf tensor, cutting gradients off from ``x``.
    # .to(x.dtype) keeps the output dtype equal to x's (e.g. torch.sum promotes int32).
    return torch.stack(values).to(x.dtype)
|
84
|
+
|
85
|
+
def pad_to_same_size(tensors : list[torch.Tensor], padding_value : int) -> torch.Tensor:
    """Right-pad 2-D tensors to a common width and concatenate them row-wise.

    Each tensor is padded along its last dimension with ``padding_value`` up to the
    widest tensor's last dimension; the padded tensors are then concatenated
    along dim 0.

    Args:
        tensors: Non-empty list of 2-D tensors ``(rows_k, width_k)``, expected to
            share dtype and device.
        padding_value: Fill value for the padded positions.

    Returns:
        Tensor of shape ``(sum(rows_k), max(width_k))`` with the inputs' dtype
        and device.

    Raises:
        ValueError: If ``tensors`` is empty.
    """
    if not tensors:
        raise ValueError("pad_to_same_size() requires at least one tensor")
    max_size = max(t.shape[-1] for t in tensors)
    padded_tensors = []
    for tensor in tensors:
        # new_full inherits dtype and device from `tensor`. A CPU torch.long pad would
        # make torch.cat fail for CUDA inputs and change the dtype of non-long inputs.
        pad = tensor.new_full((tensor.shape[0], max_size - tensor.shape[-1]), padding_value)
        padded_tensors.append(torch.cat([tensor, pad], dim=-1))
    return torch.cat(padded_tensors, dim=0)
|
92
|
+
|
79
93
|
@torch.no_grad()
|
80
94
|
def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int, max_length : int, parser : Parser, pad_token_id : int, batch_size=32, num_solutions = None, end_symb=DEFAULT_END_SYMB) -> tuple[torch.Tensor, torch.Tensor]:
|
81
95
|
assert input_ids.shape[0] == 1, "Batch size must be 1"
|
@@ -120,11 +134,15 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
|
|
120
134
|
|
121
135
|
scores_finished_current = scores_unfinished[can_end]
|
122
136
|
solutions_finished_current = solutions_unfinished[can_end]
|
123
|
-
|
137
|
+
logprob_other_ans = index_reduce_lists(logprobs[can_end], select_mask(parsers_tokens, can_end), reduce_func=torch.logsumexp).squeeze(-1)
|
138
|
+
scores_finished_current = scores_finished_current + log1mexp(logprob_other_ans)
|
124
139
|
scores_finished = torch.cat([scores_finished, scores_finished_current])
|
125
140
|
if len(solutions_finished_current):
|
126
|
-
|
127
|
-
|
141
|
+
if len(solutions_finished):
|
142
|
+
solutions_finished = pad_to_same_size([solutions_finished, solutions_finished_current],
|
143
|
+
padding_value=pad_token_id)
|
144
|
+
else:
|
145
|
+
solutions_finished = solutions_finished_current
|
128
146
|
if solutions_finished.numel():
|
129
147
|
# Keep num_solutions best solutions in finished
|
130
148
|
order = scores_finished.argsort(descending=True)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: divergent-beamsearch
|
3
|
-
Version: 0.1.6
|
3
|
+
Version: 0.1.8
|
4
4
|
Summary: A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject.
|
5
5
|
License-File: LICENCE
|
6
6
|
Requires-Python: >=3.11
|
@@ -0,0 +1,6 @@
|
|
1
|
+
divergent_beamsearch/__init__.py,sha256=qrpVRoT3d-q1N9fJnzHI2X13e71LDY4-6eLOQ_gwCqQ,62
|
2
|
+
divergent_beamsearch/algorithm.py,sha256=rywmvaIoo66aksaNdCXOPfqtd8WnCazVqYoxySi6G9s,9610
|
3
|
+
divergent_beamsearch-0.1.8.dist-info/METADATA,sha256=iZjtT-uUwN1X2EfFzPI5_ermjIMu9Myz3d4H8FWR4nw,2826
|
4
|
+
divergent_beamsearch-0.1.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
5
|
+
divergent_beamsearch-0.1.8.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
|
6
|
+
divergent_beamsearch-0.1.8.dist-info/RECORD,,
|
@@ -1,6 +0,0 @@
|
|
1
|
-
divergent_beamsearch/__init__.py,sha256=qrpVRoT3d-q1N9fJnzHI2X13e71LDY4-6eLOQ_gwCqQ,62
|
2
|
-
divergent_beamsearch/algorithm.py,sha256=73BsS5IU1_4Aj11LdQebKofpCO-Mo8BXtDfx-AEYxXA,8835
|
3
|
-
divergent_beamsearch-0.1.6.dist-info/METADATA,sha256=cm-VyQfzk9sklvIkFXgEfI4A4ktWddIO5CvtaL7Vkng,2826
|
4
|
-
divergent_beamsearch-0.1.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
5
|
-
divergent_beamsearch-0.1.6.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
|
6
|
-
divergent_beamsearch-0.1.6.dist-info/RECORD,,
|
File without changes
|
{divergent_beamsearch-0.1.6.dist-info → divergent_beamsearch-0.1.8.dist-info}/licenses/LICENCE
RENAMED
File without changes
|