divergent-beamsearch 0.1.8__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- divergent_beamsearch/algorithm.py +15 -7
- {divergent_beamsearch-0.1.8.dist-info → divergent_beamsearch-0.2.0.dist-info}/METADATA +1 -1
- divergent_beamsearch-0.2.0.dist-info/RECORD +6 -0
- divergent_beamsearch-0.1.8.dist-info/RECORD +0 -6
- {divergent_beamsearch-0.1.8.dist-info → divergent_beamsearch-0.2.0.dist-info}/WHEEL +0 -0
- {divergent_beamsearch-0.1.8.dist-info → divergent_beamsearch-0.2.0.dist-info}/licenses/LICENCE +0 -0
@@ -35,12 +35,17 @@ def apply_mask_tokens(pred : torch.Tensor, parsers_tokens):
|
|
35
35
|
return pred[~pred.isinf().all(dim=-1)]
|
36
36
|
|
37
37
|
|
38
|
-
def batched_inference_logits(model : GPT2LMHeadModel, input_ids : torch.Tensor,
|
38
|
+
def batched_inference_logits(model : GPT2LMHeadModel, input_ids : torch.Tensor,
|
39
|
+
attention_mask : torch.Tensor | None = None, batch_size : int = 32,
|
40
|
+
to_cpu=False) -> torch.Tensor:
|
39
41
|
logits = []
|
40
42
|
if attention_mask is None:
|
41
43
|
attention_mask = torch.ones_like(input_ids)
|
42
44
|
for i in range(0, input_ids.shape[0], batch_size):
|
43
|
-
|
45
|
+
l = model(input_ids[i:i+batch_size], attention_mask=attention_mask[i:i+batch_size]).logits
|
46
|
+
if to_cpu:
|
47
|
+
l = l.cpu()
|
48
|
+
logits.append(l)
|
44
49
|
return torch.cat(logits, dim=0)
|
45
50
|
|
46
51
|
def select_mask(source : list, mask : list[bool]) -> list:
|
@@ -91,7 +96,9 @@ def pad_to_same_size(tensors : list[torch.Tensor], padding_value : int) -> torch
|
|
91
96
|
return torch.cat(padded_tensors, dim=0)
|
92
97
|
|
93
98
|
@torch.no_grad()
|
94
|
-
def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int,
|
99
|
+
def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int,
|
100
|
+
max_length : int, parser : Parser, pad_token_id : int, batch_size=32,
|
101
|
+
num_solutions = None, end_symb=DEFAULT_END_SYMB, optimize_gpu_mem=True) -> tuple[torch.Tensor, torch.Tensor]:
|
95
102
|
assert input_ids.shape[0] == 1, "Batch size must be 1"
|
96
103
|
device = input_ids.device
|
97
104
|
input_ids = input_ids.cpu()
|
@@ -114,7 +121,7 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
|
|
114
121
|
for _ in range(max_length):
|
115
122
|
if len(input_ids_unfinished) == 0:
|
116
123
|
break
|
117
|
-
pred = batched_inference_logits(model, input_ids_unfinished.to(device), batch_size=batch_size)[:, -1].cpu()
|
124
|
+
pred = batched_inference_logits(model, input_ids_unfinished.to(device), batch_size=batch_size, to_cpu=optimize_gpu_mem)[:, -1].cpu()
|
118
125
|
parsers_tokens, can_end = get_parsers_tokens(parsers_unfinished, end_symb)
|
119
126
|
logprobs = torch.log_softmax(pred, dim=-1)
|
120
127
|
logprobs_filtered = apply_mask_tokens(logprobs, parsers_tokens)
|
@@ -175,19 +182,20 @@ def set_slice_row(x : torch.Tensor, slices : torch.IntTensor, value) -> torch.Te
|
|
175
182
|
@torch.no_grad()
|
176
183
|
def divergent_logprob(input_ids : torch.Tensor, attention_mask : torch.Tensor | None, model : GPT2LMHeadModel,
|
177
184
|
parsers : Parser | list[Parser] | None, batch_size=32,
|
178
|
-
start : int | torch.IntTensor = None, end_symb=DEFAULT_END_SYMB) -> torch.FloatTensor:
|
185
|
+
start : int | torch.IntTensor = None, end_symb=DEFAULT_END_SYMB, optimize_gpu_mem=True) -> torch.FloatTensor:
|
179
186
|
if start is None:
|
180
|
-
start =
|
187
|
+
start = 1
|
181
188
|
if isinstance(start, int):
|
182
189
|
start = torch.tensor([start]*input_ids.shape[0])
|
183
190
|
assert start.shape[0] == input_ids.shape[0]
|
191
|
+
assert (start > 0).all()
|
184
192
|
# -1 because next token offset
|
185
193
|
start = start - 1
|
186
194
|
|
187
195
|
if attention_mask is None:
|
188
196
|
attention_mask = torch.ones_like(input_ids)
|
189
197
|
|
190
|
-
logits = batched_inference_logits(model, input_ids, attention_mask, batch_size).cpu()
|
198
|
+
logits = batched_inference_logits(model, input_ids, attention_mask, batch_size, to_cpu=optimize_gpu_mem).cpu()
|
191
199
|
input_ids = input_ids.cpu()
|
192
200
|
attention_mask = attention_mask.cpu()
|
193
201
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: divergent-beamsearch
|
3
|
-
Version: 0.1.8
|
3
|
+
Version: 0.2.0
|
4
4
|
Summary: A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject.
|
5
5
|
License-File: LICENCE
|
6
6
|
Requires-Python: >=3.11
|
@@ -0,0 +1,6 @@
|
|
1
|
+
divergent_beamsearch/__init__.py,sha256=qrpVRoT3d-q1N9fJnzHI2X13e71LDY4-6eLOQ_gwCqQ,62
|
2
|
+
divergent_beamsearch/algorithm.py,sha256=lx27rXddHiyzisINgWI5MuatRLIU2ObnZhtCvojbGJ8,9917
|
3
|
+
divergent_beamsearch-0.2.0.dist-info/METADATA,sha256=u4-bH-9qa_yLJPemATIemwIavOCucF7CCv0kyJV6_Qg,2826
|
4
|
+
divergent_beamsearch-0.2.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
5
|
+
divergent_beamsearch-0.2.0.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
|
6
|
+
divergent_beamsearch-0.2.0.dist-info/RECORD,,
|
@@ -1,6 +0,0 @@
|
|
1
|
-
divergent_beamsearch/__init__.py,sha256=qrpVRoT3d-q1N9fJnzHI2X13e71LDY4-6eLOQ_gwCqQ,62
|
2
|
-
divergent_beamsearch/algorithm.py,sha256=rywmvaIoo66aksaNdCXOPfqtd8WnCazVqYoxySi6G9s,9610
|
3
|
-
divergent_beamsearch-0.1.8.dist-info/METADATA,sha256=iZjtT-uUwN1X2EfFzPI5_ermjIMu9Myz3d4H8FWR4nw,2826
|
4
|
-
divergent_beamsearch-0.1.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
5
|
-
divergent_beamsearch-0.1.8.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
|
6
|
-
divergent_beamsearch-0.1.8.dist-info/RECORD,,
|
File without changes
|
{divergent_beamsearch-0.1.8.dist-info → divergent_beamsearch-0.2.0.dist-info}/licenses/LICENCE
RENAMED
File without changes
|