divergent-beamsearch 0.1.5__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,7 @@
1
1
  import math
2
2
  import torch
3
3
  from transformers import GPT2LMHeadModel
4
- from multi_choices_parser import MultiChoicesParser, end_symb
4
+ from multi_choices_parser import DEFAULT_END_SYMB
5
5
 
6
6
 
7
7
  class Parser:
@@ -19,10 +19,10 @@ def get_parsers_tokens(parsers : list[Parser], end_symb) -> tuple[list, list[int
19
19
  can_end = []
20
20
  for parser in parsers:
21
21
  tokens = list(parser.next())
22
- if end_symb in tokens:
23
- can_end.append(True)
22
+ try:
24
23
  tokens.remove(end_symb)
25
- else:
24
+ can_end.append(True)
25
+ except ValueError:
26
26
  can_end.append(False)
27
27
  parsers_tokens.append(tokens)
28
28
  return parsers_tokens, can_end
@@ -77,7 +77,7 @@ class AcceptEverythingParser(Parser):
77
77
  return self
78
78
 
79
79
  @torch.no_grad()
80
- def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int, max_length : int, parser : Parser, pad_token_id : int, batch_size=32, num_solutions = None, end_symb=end_symb) -> tuple[torch.Tensor, torch.Tensor]:
80
+ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int, max_length : int, parser : Parser, pad_token_id : int, batch_size=32, num_solutions = None, end_symb=DEFAULT_END_SYMB) -> tuple[torch.Tensor, torch.Tensor]:
81
81
  assert input_ids.shape[0] == 1, "Batch size must be 1"
82
82
  device = input_ids.device
83
83
  input_ids = input_ids.cpu()
@@ -140,6 +140,7 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
140
140
  parser.step(token)
141
141
 
142
142
  # Special case of vanilla beam search where all answers are valid
143
+ # Warning : In this case model will not stop on end_of_sentence token
143
144
  if vanilla:
144
145
  order = scores_unfinished.argsort(descending=True)
145
146
  scores_finished = scores_unfinished[order][:num_solutions]
@@ -154,7 +155,9 @@ def set_slice_row(x : torch.Tensor, slices : torch.IntTensor, value) -> torch.Te
154
155
  x[i].index_fill_(0, indices[i], 0)
155
156
 
156
157
  @torch.no_grad()
157
- def divergent_logprob(input_ids : torch.Tensor, attention_mask : torch.Tensor | None, model : GPT2LMHeadModel, parsers : Parser | list[Parser] | None, batch_size=32, start : int | torch.IntTensor = None) -> torch.FloatTensor:
158
+ def divergent_logprob(input_ids : torch.Tensor, attention_mask : torch.Tensor | None, model : GPT2LMHeadModel,
159
+ parsers : Parser | list[Parser] | None, batch_size=32,
160
+ start : int | torch.IntTensor = None, end_symb=DEFAULT_END_SYMB) -> torch.FloatTensor:
158
161
  if start is None:
159
162
  start = 0
160
163
  if isinstance(start, int):
@@ -1,10 +1,10 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: divergent-beamsearch
3
- Version: 0.1.5
3
+ Version: 0.1.6
4
4
  Summary: A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject.
5
5
  License-File: LICENCE
6
6
  Requires-Python: >=3.11
7
- Requires-Dist: multi-choices-parser>=0.9.57
7
+ Requires-Dist: multi-choices-parser>=0.9.61
8
8
  Requires-Dist: torch>=2.0.0
9
9
  Requires-Dist: transformers>=4.47.1
10
10
  Description-Content-Type: text/markdown
@@ -0,0 +1,6 @@
1
+ divergent_beamsearch/__init__.py,sha256=qrpVRoT3d-q1N9fJnzHI2X13e71LDY4-6eLOQ_gwCqQ,62
2
+ divergent_beamsearch/algorithm.py,sha256=73BsS5IU1_4Aj11LdQebKofpCO-Mo8BXtDfx-AEYxXA,8835
3
+ divergent_beamsearch-0.1.6.dist-info/METADATA,sha256=cm-VyQfzk9sklvIkFXgEfI4A4ktWddIO5CvtaL7Vkng,2826
4
+ divergent_beamsearch-0.1.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
+ divergent_beamsearch-0.1.6.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
6
+ divergent_beamsearch-0.1.6.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- divergent_beamsearch/__init__.py,sha256=qrpVRoT3d-q1N9fJnzHI2X13e71LDY4-6eLOQ_gwCqQ,62
2
- divergent_beamsearch/algorithm.py,sha256=b44kA2_M0HfOSC6LYQmu7W_JnGAl1u9Sz_91jMjvWg0,8688
3
- divergent_beamsearch-0.1.5.dist-info/METADATA,sha256=sRitbfgDp8YqLXZzZmZw7Nd1pJWv8T4HJoSqjhO4ztI,2826
4
- divergent_beamsearch-0.1.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
- divergent_beamsearch-0.1.5.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
6
- divergent_beamsearch-0.1.5.dist-info/RECORD,,