divergent-beamsearch 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl

divergent_beamsearch/algorithm.py

@@ -82,6 +82,14 @@ def index_reduce_lists(x : torch.Tensor, indices : list[list[int]], reduce_func=
         values.append(reduce_func(x[i, index], dim=-1))
     return torch.tensor(values, dtype=x.dtype, device=x.device, requires_grad=x.requires_grad)
 
+def pad_to_same_size(tensors : list[torch.Tensor], padding_value : int) -> torch.Tensor:
+    max_size = max(x.shape[-1] for x in tensors)
+    padded_tensors = []
+    for tensor in tensors:
+        pad = torch.full((tensor.shape[0], max_size - tensor.shape[1]), padding_value, dtype=torch.long)
+        padded_tensors.append(torch.cat([tensor, pad], dim=-1))
+    return torch.cat(padded_tensors, dim=0)
+
 @torch.no_grad()
 def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam_size : int, max_length : int, parser : Parser, pad_token_id : int, batch_size=32, num_solutions = None, end_symb=DEFAULT_END_SYMB) -> tuple[torch.Tensor, torch.Tensor]:
     assert input_ids.shape[0] == 1, "Batch size must be 1"
@@ -130,8 +138,11 @@ def divergent_beamsearch(input_ids : torch.Tensor, model : GPT2LMHeadModel, beam
         scores_finished_current = scores_finished_current + log1mexp(logprob_other_ans)
         scores_finished = torch.cat([scores_finished, scores_finished_current])
         if len(solutions_finished_current):
-            pad = torch.full((len(scores_finished_current), solutions_finished_current.shape[1] - solutions_finished.shape[1]), pad_token_id, dtype=torch.long)
-            solutions_finished = torch.cat([solutions_finished.view(-1, solutions_finished_current.shape[1]+pad.shape[1]), torch.cat([solutions_finished_current, pad], dim=1)], dim=0)
+            if len(solutions_finished):
+                solutions_finished = pad_to_same_size([solutions_finished, solutions_finished_current],
+                                                      padding_value=pad_token_id)
+            else:
+                solutions_finished = solutions_finished_current
         if solutions_finished.numel():
             # Keep num_solutions best solutions in finished
             order = scores_finished.argsort(descending=True)
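
Note: the new pad_to_same_size helper right-pads each 2-D token-id tensor in the list to the widest length and stacks them along the batch dimension; the hunk above uses it to merge previously finished solutions with newly finished ones. A minimal sketch of the expected behaviour (the import path from divergent_beamsearch.algorithm and the example values are assumptions for illustration only):

    import torch
    from divergent_beamsearch.algorithm import pad_to_same_size  # assumed import path

    a = torch.tensor([[1, 2, 3]])  # shape (1, 3)
    b = torch.tensor([[4, 5]])     # shape (1, 2)

    # b is right-padded with the padding value to width 3, then both are
    # concatenated along dim 0:
    # tensor([[1, 2, 3],
    #         [4, 5, 0]])
    print(pad_to_same_size([a, b], padding_value=0))
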
divergent_beamsearch-0.1.7.dist-info/METADATA → divergent_beamsearch-0.1.8.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: divergent-beamsearch
-Version: 0.1.7
+Version: 0.1.8
 Summary: A variant of the beam search algorithm that focuses on finding answers that maximize the probability of generating an answer before diverging into another subject.
 License-File: LICENCE
 Requires-Python: >=3.11
divergent_beamsearch-0.1.8.dist-info/RECORD (added)

@@ -0,0 +1,6 @@
+divergent_beamsearch/__init__.py,sha256=qrpVRoT3d-q1N9fJnzHI2X13e71LDY4-6eLOQ_gwCqQ,62
+divergent_beamsearch/algorithm.py,sha256=rywmvaIoo66aksaNdCXOPfqtd8WnCazVqYoxySi6G9s,9610
+divergent_beamsearch-0.1.8.dist-info/METADATA,sha256=iZjtT-uUwN1X2EfFzPI5_ermjIMu9Myz3d4H8FWR4nw,2826
+divergent_beamsearch-0.1.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+divergent_beamsearch-0.1.8.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
+divergent_beamsearch-0.1.8.dist-info/RECORD,,
divergent_beamsearch-0.1.7.dist-info/RECORD (removed)

@@ -1,6 +0,0 @@
-divergent_beamsearch/__init__.py,sha256=qrpVRoT3d-q1N9fJnzHI2X13e71LDY4-6eLOQ_gwCqQ,62
-divergent_beamsearch/algorithm.py,sha256=TUm2pbFhR0DqfGDm1fqQXqojNCAkFRmuvj4jbFCxwHc,9228
-divergent_beamsearch-0.1.7.dist-info/METADATA,sha256=JWuN6f2YjjOXoYxAFzR7vmVYwPyL2HDXI7huY67gAmU,2826
-divergent_beamsearch-0.1.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-divergent_beamsearch-0.1.7.dist-info/licenses/LICENCE,sha256=gnISbTzmuQC7NwJaGOdjoq26QYgSuKndq5q2JykifKw,1075
-divergent_beamsearch-0.1.7.dist-info/RECORD,,