divergent_beamsearch-0.1.5-py3-none-any.whl → divergent_beamsearch-0.1.7-py3-none-any.whl

divergent_beamsearch/algorithm.py

@@ -1,7 +1,7 @@
 import math
 import torch
 from transformers import GPT2LMHeadModel
-from multi_choices_parser import MultiChoicesParser, end_symb
+from multi_choices_parser import DEFAULT_END_SYMB


 class Parser:
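The import switch tracks the parser library's API: the end-of-parse sentinel formerly imported as end_symb is now the module-level DEFAULT_END_SYMB (hence the multi-choices-parser>=0.9.61 bump in METADATA below), and the unused MultiChoicesParser import is dropped. A minimal compatibility sketch, assuming only the two import forms visible in this diff:

    # Hedged sketch: resolve the end-of-parse sentinel under either
    # multi_choices_parser API; only the two imports shown in this diff
    # are assumed to exist.
    try:
        from multi_choices_parser import DEFAULT_END_SYMB as END_SYMB  # >= 0.9.61
    except ImportError:
        from multi_choices_parser import end_symb as END_SYMB  # older releases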
@@ -19,10 +19,10 @@ def get_parsers_tokens(parsers : list[Parser], end_symb) -> tuple[list, list[int
     can_end = []
     for parser in parsers:
         tokens = list(parser.next())
-        if end_symb in tokens:
-            can_end.append(True)
+        try:
             tokens.remove(end_symb)
-        else:
+            can_end.append(True)
+        except ValueError:
             can_end.append(False)
         parsers_tokens.append(tokens)
     return parsers_tokens, can_end
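This rewrite replaces a membership test followed by a remove (two scans of the token list when the end symbol is present) with a single remove guarded by try/except, the usual EAFP idiom: remove() raising ValueError doubles as the membership test. A self-contained sketch of the pattern; END is a stand-in object, not the library sentinel:

    # One list scan instead of two: remove() raises ValueError when the
    # end symbol is absent.
    END = object()

    def split_end_symbol(tokens: list) -> tuple[list, bool]:
        tokens = list(tokens)
        try:
            tokens.remove(END)   # raises ValueError if END is missing
            return tokens, True  # this parser may stop here
        except ValueError:
            return tokens, False

    print(split_end_symbol([1, 2, END]))  # ([1, 2], True)
    print(split_end_symbol([1, 2]))       # ([1, 2], False)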
@@ -76,8 +76,14 @@ class AcceptEverythingParser(Parser):
     def copy(self):
         return self

+def index_reduce_lists(x : torch.Tensor, indices : list[list[int]], reduce_func=torch.sum) -> torch.Tensor:
+    values = []
+    for i, index in enumerate(indices):
+        values.append(reduce_func(x[i, index], dim=-1))
+    return torch.tensor(values, dtype=x.dtype, device=x.device, requires_grad=x.requires_grad)
+
 @torch.no_grad()
-def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int, max_length : int, parser : Parser, pad_token_id : int, batch_size=32, num_solutions = None, end_symb=end_symb) -> tuple[torch.Tensor, torch.Tensor]:
+def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int, max_length : int, parser : Parser, pad_token_id : int, batch_size=32, num_solutions = None, end_symb=DEFAULT_END_SYMB) -> tuple[torch.Tensor, torch.Tensor]:
     assert input_ids.shape[0] == 1, "Batch size must be 1"
     device = input_ids.device
     input_ids = input_ids.cpu()
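Two changes here: divergent_beamsearch's end_symb default switches to the renamed DEFAULT_END_SYMB, and a new index_reduce_lists helper reduces each row of x over a per-row list of column indices. The lists may be ragged (each beam's parser can allow a different number of next tokens), which rules out a single rectangular fancy-indexing expression. A usage sketch of the helper defined in the hunk above; shapes and values are illustrative:

    import torch

    logprobs = torch.log_softmax(torch.randn(2, 5), dim=-1)
    allowed = [[0, 3], [1, 2, 4]]  # ragged: 2 tokens for row 0, 3 for row 1

    per_row = index_reduce_lists(logprobs, allowed, reduce_func=torch.logsumexp)

    # Equivalent row-by-row computation:
    expected = torch.stack([logprobs[0, [0, 3]].logsumexp(dim=-1),
                            logprobs[1, [1, 2, 4]].logsumexp(dim=-1)])
    assert torch.allclose(per_row, expected)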
@@ -120,7 +126,8 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam

         scores_finished_current = scores_unfinished[can_end]
         solutions_finished_current = solutions_unfinished[can_end]
-        scores_finished_current = scores_finished_current + log1mexp(logprobs[can_end, select_mask(parsers_tokens, can_end)].logsumexp(dim=-1)).squeeze(-1)
+        logprob_other_ans = index_reduce_lists(logprobs[can_end], select_mask(parsers_tokens, can_end), reduce_func=torch.logsumexp).squeeze(-1)
+        scores_finished_current = scores_finished_current + log1mexp(logprob_other_ans)
         scores_finished = torch.cat([scores_finished, scores_finished_current])
         if len(solutions_finished_current):
             pad = torch.full((len(scores_finished_current), solutions_finished_current.shape[1] - solutions_finished.shape[1]), pad_token_id, dtype=torch.long)
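This is the divergence correction: when a beam can legally end, its finished score is penalized by log(1 - P(continuing into some other allowed token)), i.e. log1mexp of the logsumexp over the remaining allowed tokens' log-probabilities. The old one-liner computed that logsumexp via rectangular fancy indexing, which fails when beams allow different numbers of tokens; routing through index_reduce_lists handles the ragged case. The package's log1mexp is not shown in this diff; a standard numerically stable sketch of what such a helper computes:

    import math
    import torch

    # log1mexp(a) = log(1 - exp(a)) for a < 0, using the usual two-branch
    # trick: log1p(-exp(a)) when a < -ln 2, log(-expm1(a)) otherwise.
    def log1mexp(a: torch.Tensor) -> torch.Tensor:
        return torch.where(a < -math.log(2),
                           torch.log1p(-torch.exp(a)),
                           torch.log(-torch.expm1(a)))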
@@ -140,6 +147,7 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
             parser.step(token)

     # Special case of vanilla beam search where all answers are valid
+    # Warning: in this case the model will not stop on the end_of_sentence token
     if vanilla:
         order = scores_unfinished.argsort(descending=True)
         scores_finished = scores_unfinished[order][:num_solutions]
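The new comment documents a real caveat of the vanilla fallback: when every continuation is a valid answer, nothing ever consumes an end symbol, so only max_length bounds generation. A hedged usage sketch; the AcceptEverythingParser constructor argument and the return-value names are assumptions (neither is shown in this diff), and the model/tokenizer choices are illustrative:

    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    input_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids

    # Vanilla mode: keep max_length small, since EOS will not stop the search.
    scores, solutions = divergent_beamsearch(
        input_ids, model, beam_size=4, max_length=8,
        parser=AcceptEverythingParser(tokenizer.vocab_size),  # constructor assumed
        pad_token_id=tokenizer.eos_token_id,
    )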
@@ -154,7 +162,9 @@ def set_slice_row(x : torch.Tensor, slices : torch.IntTensor, value) -> torch.Te
         x[i].index_fill_(0, indices[i], 0)

 @torch.no_grad()
-def divergent_logprob(input_ids : torch.Tensor, attention_mask : torch.Tensor | None, model : GPT2LMHeadModel, parsers : Parser | list[Parser] | None, batch_size=32, start : int | torch.IntTensor = None) -> torch.FloatTensor:
+def divergent_logprob(input_ids : torch.Tensor, attention_mask : torch.Tensor | None, model : GPT2LMHeadModel,
+                      parsers : Parser | list[Parser] | None, batch_size=32,
+                      start : int | torch.IntTensor = None, end_symb=DEFAULT_END_SYMB) -> torch.FloatTensor:
     if start is None:
         start = 0
     if isinstance(start, int):
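Besides the cosmetic re-wrap, divergent_logprob gains the same end_symb keyword (default DEFAULT_END_SYMB) as divergent_beamsearch, so a custom parser end symbol can now be threaded through both entry points. A hedged call sketch reusing the model and tokenizer from the previous example; the start semantics (index where the scored answer begins) are an assumption, since the function body is not part of this hunk:

    enc = tokenizer("Paris is the capital of France", return_tensors="pt")

    logp = divergent_logprob(
        enc.input_ids, attention_mask=enc.attention_mask, model=model,
        parsers=None,   # None is permitted by the type hint; behavior assumed
        start=5,        # assumed: score the continuation from token index 5 on
        # end_symb left at its new default, DEFAULT_END_SYMB
    )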
divergent_beamsearch-0.1.7.dist-info/METADATA

@@ -1,10 +1,10 @@
 Metadata-Version: 2.4
 Name: divergent-beamsearch
-Version: 0.1.5
+Version: 0.1.7
 Summary: A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject.
 License-File: LICENCE
 Requires-Python: >=3.11
-Requires-Dist: multi-choices-parser>=0.9.57
+Requires-Dist: multi-choices-parser>=0.9.61
 Requires-Dist: torch>=2.0.0
 Requires-Dist: transformers>=4.47.1
 Description-Content-Type: text/markdown
divergent_beamsearch-0.1.7.dist-info/RECORD (added)

@@ -0,0 +1,6 @@
+divergent_beamsearch/__init__.py,sha256=qrpVRoT3d-q1N9fJnzHI2X13e71LDY4-6eLOQ_gwCqQ,62
+divergent_beamsearch/algorithm.py,sha256=TUm2pbFhR0DqfGDm1fqQXqojNCAkFRmuvj4jbFCxwHc,9228
+divergent_beamsearch-0.1.7.dist-info/METADATA,sha256=JWuN6f2YjjOXoYxAFzR7vmVYwPyL2HDXI7huY67gAmU,2826
+divergent_beamsearch-0.1.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+divergent_beamsearch-0.1.7.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
+divergent_beamsearch-0.1.7.dist-info/RECORD,,
divergent_beamsearch-0.1.5.dist-info/RECORD (removed)

@@ -1,6 +0,0 @@
-divergent_beamsearch/__init__.py,sha256=qrpVRoT3d-q1N9fJnzHI2X13e71LDY4-6eLOQ_gwCqQ,62
-divergent_beamsearch/algorithm.py,sha256=b44kA2_M0HfOSC6LYQmu7W_JnGAl1u9Sz_91jMjvWg0,8688
-divergent_beamsearch-0.1.5.dist-info/METADATA,sha256=sRitbfgDp8YqLXZzZmZw7Nd1pJWv8T4HJoSqjhO4ztI,2826
-divergent_beamsearch-0.1.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-divergent_beamsearch-0.1.5.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
-divergent_beamsearch-0.1.5.dist-info/RECORD,,