divergent_beamsearch-0.1.5-py3-none-any.whl → divergent_beamsearch-0.1.7-py3-none-any.whl

This diff compares the published contents of two package versions as they appear in their public registry and is provided for informational purposes only.
--- divergent_beamsearch/algorithm.py (0.1.5)
+++ divergent_beamsearch/algorithm.py (0.1.7)
@@ -1,7 +1,7 @@
 import math
 import torch
 from transformers import GPT2LMHeadModel
-from multi_choices_parser import MultiChoicesParser, end_symb
+from multi_choices_parser import DEFAULT_END_SYMB
 
 
 class Parser:
@@ -19,10 +19,10 @@ def get_parsers_tokens(parsers : list[Parser], end_symb) -> tuple[list, list[int
     can_end = []
     for parser in parsers:
        tokens = list(parser.next())
-        if end_symb in tokens:
-            can_end.append(True)
+        try:
             tokens.remove(end_symb)
-        else:
+            can_end.append(True)
+        except ValueError:
             can_end.append(False)
         parsers_tokens.append(tokens)
     return parsers_tokens, can_end
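
Note on the hunk above: 0.1.7 swaps the membership test for EAFP-style exception handling. `list.remove` raises `ValueError` when the element is absent, so a single pass both strips the end symbol and records whether the parser can end, instead of scanning the token list twice (`in` plus `remove`). A minimal sketch of the pattern, with a hypothetical stand-in for the end symbol:

```python
END = object()  # hypothetical stand-in for the parser's end symbol

def strip_end_symbol(tokens: list) -> tuple[list, bool]:
    # list.remove raises ValueError if END is absent, so one try/except
    # both removes the symbol and reports whether it was present.
    try:
        tokens.remove(END)
        return tokens, True
    except ValueError:
        return tokens, False

print(strip_end_symbol([1, 2, END]))  # ([1, 2], True)
print(strip_end_symbol([1, 2]))       # ([1, 2], False)
```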
@@ -76,8 +76,14 @@ class AcceptEverythingParser(Parser):
     def copy(self):
         return self
 
+def index_reduce_lists(x : torch.Tensor, indices : list[list[int]], reduce_func=torch.sum) -> torch.Tensor:
+    values = []
+    for i, index in enumerate(indices):
+        values.append(reduce_func(x[i, index], dim=-1))
+    return torch.tensor(values, dtype=x.dtype, device=x.device, requires_grad=x.requires_grad)
+
 @torch.no_grad()
-def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int, max_length : int, parser : Parser, pad_token_id : int, batch_size=32, num_solutions = None, end_symb=end_symb) -> tuple[torch.Tensor, torch.Tensor]:
+def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int, max_length : int, parser : Parser, pad_token_id : int, batch_size=32, num_solutions = None, end_symb=DEFAULT_END_SYMB) -> tuple[torch.Tensor, torch.Tensor]:
     assert input_ids.shape[0] == 1, "Batch size must be 1"
     device = input_ids.device
     input_ids = input_ids.cpu()
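
The new `index_reduce_lists` helper reduces each row of `x` over its own index list. The lists may have different lengths, which is why a single batched index expression such as `x[:, indices]` cannot express this. A quick illustration with `torch.logsumexp`, the reduction the search uses below (the values here are made up):

```python
import torch

def index_reduce_lists(x, indices, reduce_func=torch.sum):
    values = []
    for i, index in enumerate(indices):
        values.append(reduce_func(x[i, index], dim=-1))
    return torch.tensor(values, dtype=x.dtype, device=x.device, requires_grad=x.requires_grad)

logprobs = torch.log(torch.tensor([[0.5, 0.3, 0.2],
                                   [0.1, 0.6, 0.3]]))
# Row 0 reduces over two entries, row 1 over one: the lists are ragged,
# so each row gets its own reduction.
out = index_reduce_lists(logprobs, [[0, 1], [2]], reduce_func=torch.logsumexp)
print(out.exp())  # tensor([0.8000, 0.3000])
```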
@@ -120,7 +126,8 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
 
         scores_finished_current = scores_unfinished[can_end]
         solutions_finished_current = solutions_unfinished[can_end]
-        scores_finished_current = scores_finished_current + log1mexp(logprobs[can_end, select_mask(parsers_tokens, can_end)].logsumexp(dim=-1)).squeeze(-1)
+        logprob_other_ans = index_reduce_lists(logprobs[can_end], select_mask(parsers_tokens, can_end), reduce_func=torch.logsumexp).squeeze(-1)
+        scores_finished_current = scores_finished_current + log1mexp(logprob_other_ans)
         scores_finished = torch.cat([scores_finished, scores_finished_current])
         if len(solutions_finished_current):
             pad = torch.full((len(scores_finished_current), solutions_finished_current.shape[1] - solutions_finished.shape[1]), pad_token_id, dtype=torch.long)
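
The scoring rule itself is unchanged: a beam that can legally end gets `log1mexp(L)` added to its score, where `L` is the logsumexp of the log-probabilities of the tokens that would continue into another valid answer, i.e. the beam is credited with `log(1 - P(continue))`. 0.1.7 only reroutes that logsumexp through `index_reduce_lists` so the ragged per-beam token lists are handled correctly. `log1mexp` is defined elsewhere in `algorithm.py`; a standard numerically stable version (an assumption, not the package's exact code) looks like:

```python
import math
import torch

def log1mexp(x: torch.Tensor) -> torch.Tensor:
    # Computes log(1 - exp(x)) for x < 0 without catastrophic cancellation:
    # log(-expm1(x)) is accurate near 0, log1p(-exp(x)) for very negative x.
    return torch.where(
        x > -math.log(2),
        torch.log(-torch.expm1(x)),
        torch.log1p(-torch.exp(x)),
    )
```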
@@ -140,6 +147,7 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
             parser.step(token)
 
     # Special case of vanilla beam search where all answers are valid
+    # Warning : In this case model will not stop on end_of_sentence token
     if vanilla:
         order = scores_unfinished.argsort(descending=True)
         scores_finished = scores_unfinished[order][:num_solutions]
@@ -154,7 +162,9 @@ def set_slice_row(x : torch.Tensor, slices : torch.IntTensor, value) -> torch.Te
         x[i].index_fill_(0, indices[i], 0)
 
 @torch.no_grad()
-def divergent_logprob(input_ids : torch.Tensor, attention_mask : torch.Tensor | None, model : GPT2LMHeadModel, parsers : Parser | list[Parser] | None, batch_size=32, start : int | torch.IntTensor = None) -> torch.FloatTensor:
+def divergent_logprob(input_ids : torch.Tensor, attention_mask : torch.Tensor | None, model : GPT2LMHeadModel,
+                      parsers : Parser | list[Parser] | None, batch_size=32,
+                      start : int | torch.IntTensor = None, end_symb=DEFAULT_END_SYMB) -> torch.FloatTensor:
     if start is None:
         start = 0
     if isinstance(start, int):
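
Both public entry points now accept `end_symb`, defaulting to `DEFAULT_END_SYMB` from `multi-choices-parser` (hence the bumped `>=0.9.61` requirement) instead of the removed `end_symb` import. A hypothetical call against the 0.1.7 signature; the import path and the `AcceptEverythingParser` constructor are assumptions, only the `divergent_beamsearch` signature is confirmed by the diff:

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# Assumed import path; the module file is divergent_beamsearch/algorithm.py.
from divergent_beamsearch.algorithm import AcceptEverythingParser, divergent_beamsearch

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids
parser = AcceptEverythingParser()  # hypothetical no-argument construction

scores, solutions = divergent_beamsearch(
    input_ids, model, beam_size=5, max_length=10, parser=parser,
    pad_token_id=tokenizer.eos_token_id, num_solutions=3,
)  # end_symb defaults to DEFAULT_END_SYMB as of 0.1.7
```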
--- divergent_beamsearch-0.1.5.dist-info/METADATA
+++ divergent_beamsearch-0.1.7.dist-info/METADATA
@@ -1,10 +1,10 @@
 Metadata-Version: 2.4
 Name: divergent-beamsearch
-Version: 0.1.5
+Version: 0.1.7
 Summary: A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject.
 License-File: LICENCE
 Requires-Python: >=3.11
-Requires-Dist: multi-choices-parser>=0.9.57
+Requires-Dist: multi-choices-parser>=0.9.61
 Requires-Dist: torch>=2.0.0
 Requires-Dist: transformers>=4.47.1
 Description-Content-Type: text/markdown
--- /dev/null
+++ divergent_beamsearch-0.1.7.dist-info/RECORD
@@ -0,0 +1,6 @@
+divergent_beamsearch/__init__.py,sha256=qrpVRoT3d-q1N9fJnzHI2X13e71LDY4-6eLOQ_gwCqQ,62
+divergent_beamsearch/algorithm.py,sha256=TUm2pbFhR0DqfGDm1fqQXqojNCAkFRmuvj4jbFCxwHc,9228
+divergent_beamsearch-0.1.7.dist-info/METADATA,sha256=JWuN6f2YjjOXoYxAFzR7vmVYwPyL2HDXI7huY67gAmU,2826
+divergent_beamsearch-0.1.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+divergent_beamsearch-0.1.7.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
+divergent_beamsearch-0.1.7.dist-info/RECORD,,
--- divergent_beamsearch-0.1.5.dist-info/RECORD
+++ /dev/null
@@ -1,6 +0,0 @@
-divergent_beamsearch/__init__.py,sha256=qrpVRoT3d-q1N9fJnzHI2X13e71LDY4-6eLOQ_gwCqQ,62
-divergent_beamsearch/algorithm.py,sha256=b44kA2_M0HfOSC6LYQmu7W_JnGAl1u9Sz_91jMjvWg0,8688
-divergent_beamsearch-0.1.5.dist-info/METADATA,sha256=sRitbfgDp8YqLXZzZmZw7Nd1pJWv8T4HJoSqjhO4ztI,2826
-divergent_beamsearch-0.1.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-divergent_beamsearch-0.1.5.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
-divergent_beamsearch-0.1.5.dist-info/RECORD,,