TruthTorchLM 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- TruthTorchLM/__init__.py +16 -0
- TruthTorchLM/availability.py +14 -0
- TruthTorchLM/calibration.py +36 -0
- TruthTorchLM/evaluators/__init__.py +8 -0
- TruthTorchLM/evaluators/bleu.py +20 -0
- TruthTorchLM/evaluators/correctness_evaluator.py +14 -0
- TruthTorchLM/evaluators/eval_truth_method.py +59 -0
- TruthTorchLM/evaluators/model_judge.py +61 -0
- TruthTorchLM/evaluators/rouge.py +19 -0
- TruthTorchLM/generation.py +389 -0
- TruthTorchLM/long_form_generation/__init__.py +5 -0
- TruthTorchLM/long_form_generation/decomposition_methods/__init__.py +8 -0
- TruthTorchLM/long_form_generation/decomposition_methods/decomposition_method.py +27 -0
- TruthTorchLM/long_form_generation/decomposition_methods/structured_decomposition_api.py +50 -0
- TruthTorchLM/long_form_generation/decomposition_methods/structured_decomposition_local.py +43 -0
- TruthTorchLM/long_form_generation/decomposition_methods/unstructured_decomposition_api.py +50 -0
- TruthTorchLM/long_form_generation/decomposition_methods/unstructured_decomposition_local.py +65 -0
- TruthTorchLM/long_form_generation/evaluators/__init__.py +4 -0
- TruthTorchLM/long_form_generation/evaluators/eval_claim.py +223 -0
- TruthTorchLM/long_form_generation/evaluators/long_gen_eval.py +158 -0
- TruthTorchLM/long_form_generation/generation.py +167 -0
- TruthTorchLM/long_form_generation/statement_check_methods/__init__.py +7 -0
- TruthTorchLM/long_form_generation/statement_check_methods/answer_statement_entailment.py +219 -0
- TruthTorchLM/long_form_generation/statement_check_methods/question_answer_generation.py +354 -0
- TruthTorchLM/long_form_generation/statement_check_methods/question_generation.py +293 -0
- TruthTorchLM/long_form_generation/statement_check_methods/statement_check_method.py +46 -0
- TruthTorchLM/long_form_generation/utils/__init__.py +3 -0
- TruthTorchLM/long_form_generation/utils/dataset_utils.py +90 -0
- TruthTorchLM/long_form_generation/utils/eval_utils.py +188 -0
- TruthTorchLM/long_form_generation/utils/safe_utils.py +231 -0
- TruthTorchLM/normalizers/__init__.py +4 -0
- TruthTorchLM/normalizers/normalizer.py +36 -0
- TruthTorchLM/normalizers/sigmoid_normalizer.py +34 -0
- TruthTorchLM/scoring_methods/__init__.py +5 -0
- TruthTorchLM/scoring_methods/length_normalized_scoring.py +12 -0
- TruthTorchLM/scoring_methods/log_prob_scoring.py +11 -0
- TruthTorchLM/scoring_methods/scoring_method.py +19 -0
- TruthTorchLM/templates.py +169 -0
- TruthTorchLM/truth_methods/__init__.py +31 -0
- TruthTorchLM/truth_methods/attention_score.py +52 -0
- TruthTorchLM/truth_methods/confidence.py +59 -0
- TruthTorchLM/truth_methods/cross_examination.py +164 -0
- TruthTorchLM/truth_methods/eccentricity_confidence.py +74 -0
- TruthTorchLM/truth_methods/eccentricity_uncertainty.py +69 -0
- TruthTorchLM/truth_methods/entropy.py +66 -0
- TruthTorchLM/truth_methods/google_search_check.py +144 -0
- TruthTorchLM/truth_methods/inside.py +49 -0
- TruthTorchLM/truth_methods/kernel_language_entropy.py +81 -0
- TruthTorchLM/truth_methods/lars.py +479 -0
- TruthTorchLM/truth_methods/mars.py +196 -0
- TruthTorchLM/truth_methods/matrix_degree_confidence.py +78 -0
- TruthTorchLM/truth_methods/matrix_degree_uncertainty.py +74 -0
- TruthTorchLM/truth_methods/multi_llm_collab.py +535 -0
- TruthTorchLM/truth_methods/num_semantic_set_uncertainty.py +70 -0
- TruthTorchLM/truth_methods/p_true.py +71 -0
- TruthTorchLM/truth_methods/saplma.py +206 -0
- TruthTorchLM/truth_methods/self_detection.py +133 -0
- TruthTorchLM/truth_methods/semantic_entropy.py +93 -0
- TruthTorchLM/truth_methods/sentSAR.py +101 -0
- TruthTorchLM/truth_methods/sum_eigen_uncertainty.py +71 -0
- TruthTorchLM/truth_methods/tokenSAR.py +76 -0
- TruthTorchLM/truth_methods/truth_method.py +73 -0
- TruthTorchLM/truth_methods/verbalized_confidence.py +77 -0
- TruthTorchLM/utils/__init__.py +5 -0
- TruthTorchLM/utils/calibration_utils.py +64 -0
- TruthTorchLM/utils/common_utils.py +374 -0
- TruthTorchLM/utils/dataset_utils.py +127 -0
- TruthTorchLM/utils/eval_utils.py +280 -0
- TruthTorchLM/utils/google_search_utils.py +136 -0
- truthtorchlm-0.0.0.dist-info/LICENSE +21 -0
- truthtorchlm-0.0.0.dist-info/LICENSE copy +21 -0
- truthtorchlm-0.0.0.dist-info/METADATA +209 -0
- truthtorchlm-0.0.0.dist-info/RECORD +75 -0
- truthtorchlm-0.0.0.dist-info/WHEEL +5 -0
- truthtorchlm-0.0.0.dist-info/top_level.txt +1 -0
TruthTorchLM/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from .truth_methods.truth_method import TruthMethod
|
|
2
|
+
from TruthTorchLM import utils ##TODO do we really need to import this?
|
|
3
|
+
from TruthTorchLM import scoring_methods
|
|
4
|
+
from TruthTorchLM import truth_methods
|
|
5
|
+
from .generation import generate_with_truth_value
|
|
6
|
+
from .calibration import calibrate_truth_method
|
|
7
|
+
from TruthTorchLM import evaluators
|
|
8
|
+
from .evaluators import evaluate_truth_method
|
|
9
|
+
from .templates import DEFAULT_USER_PROMPT, DEFAULT_SYSTEM_PROMPT ##TODO import all?
|
|
10
|
+
from .availability import AVAILABLE_DATASETS, AVAILABLE_EVALUATION_METRICS
|
|
11
|
+
from TruthTorchLM import normalizers
|
|
12
|
+
|
|
13
|
+
from TruthTorchLM import long_form_generation
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
#__all__ = ['generate_with_truth_value']
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Model names accepted by the API (string-model) generation path; any other
# string model is rejected with a ValueError in generate_with_truth_value_api.
AVAILABLE_API_MODELS = [
    'gpt-4o', 'gpt-4o-2024-05-13', 'gpt-4o-2024-08-06', 'chatgpt-4o-latest',
    'gpt-4o-mini', 'gpt-4o-mini-2024-07-18',
    'gpt-4-turbo', 'gpt-4-turbo-2024-04-09', 'gpt-4-turbo-preview',
    'gpt-4-0125-preview', 'gpt-4-1106-preview', 'gpt-4',
    'gpt-4-0613', 'gpt-4-0314',
    'gpt-3.5-turbo-0125', 'gpt-3.5-turbo', 'gpt-3.5-turbo-1106', 'gpt-3.5-turbo-instruct',
    'together_ai/togethercomputer/llama-2-70b',
]

# API models allowed when a truth method sets REQUIRES_LOGPROBS. Currently the
# same set as above; kept as an independent list so entries can be removed for
# providers that do not return token log-probabilities.
PROB_AVAILABLE_API_MODELS = list(AVAILABLE_API_MODELS)

# API models whose internal activations can be used; none at the moment.
ACTIVATION_AVAILABLE_API_MODELS = []

# Short-form QA datasets understood by the dataset loader.
AVAILABLE_DATASETS = ['trivia_qa', 'gsm8k', 'natural_qa', 'pop_qa', 'simple_qa']
# Datasets used by the long-form generation pipeline.
LONG_FORM_AVAILABLE_DATASETS = ['longfact_concepts', 'longfact_objects']

# Metric names accepted by evaluate_truth_method.
AVAILABLE_EVALUATION_METRICS = ['auroc', 'auprc', 'auarc', 'accuracy', 'f1', 'precision', 'recall', 'prr']
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
|
|
2
|
+
from typing import Union
|
|
3
|
+
from TruthTorchLM.truth_methods import TruthMethod
|
|
4
|
+
from TruthTorchLM.evaluators import CorrectnessEvaluator, ROUGE
|
|
5
|
+
from TruthTorchLM.templates import DEFAULT_SYSTEM_BENCHMARK_PROMPT, DEFAULT_USER_PROMPT
|
|
6
|
+
from TruthTorchLM.utils.dataset_utils import get_dataset
|
|
7
|
+
from TruthTorchLM.utils.eval_utils import run_over_dataset
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def calibrate_truth_method(dataset: Union[str, list], model: Union[str, PreTrainedModel], truth_methods: list[TruthMethod],
                           tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
                           correctness_evaluator: CorrectnessEvaluator = ROUGE(0.7), size_of_data: float = 1.0,
                           previous_context: list = None,
                           user_prompt: str = DEFAULT_USER_PROMPT, seed: int = 0, wandb_run=None,
                           return_method_details: bool = False, wandb_push_method_details: bool = False,
                           split='train', **kwargs):
    """Run the truth methods over a dataset split and fit each method's normalizer.

    Args:
        dataset: dataset name or an already-loaded list of samples (passed to ``get_dataset``).
        model: API model name or local Hugging Face model.
        truth_methods: truth methods whose ``normalizer.calibrate`` is fitted.
        tokenizer: tokenizer for a local model.
        correctness_evaluator: grades each generation (1 correct, 0 incorrect, -1 not attempted).
        size_of_data: fraction/size of the dataset to use (forwarded to ``get_dataset``).
        previous_context: chat messages prepended to every question; defaults to a single
            system message containing ``DEFAULT_SYSTEM_BENCHMARK_PROMPT``.
        user_prompt, seed, wandb_run, return_method_details, wandb_push_method_details, **kwargs:
            forwarded to ``run_over_dataset``.
        split: dataset split to load.

    Returns:
        The ``output_dict`` produced by ``run_over_dataset``.
    """
    if previous_context is None:
        # Built per call so callers never share (and accidentally mutate) one default list.
        previous_context = [{'role': 'system', 'content': DEFAULT_SYSTEM_BENCHMARK_PROMPT}]

    dataset = get_dataset(dataset, size_of_data=size_of_data, seed=seed, split=split)

    output_dict = run_over_dataset(dataset, model, truth_methods, tokenizer=tokenizer, correctness_evaluator=correctness_evaluator,
                                   previous_context=previous_context, user_prompt=user_prompt, seed=seed,
                                   return_method_details=return_method_details,
                                   wandb_run=wandb_run, wandb_push_method_details=wandb_push_method_details, **kwargs)

    correctness = np.asarray(output_dict['generation_correctness'])
    # A correctness of -1 means the model did not attempt an answer; such samples
    # carry no correct/incorrect label, so they are excluded from calibration.
    attempted = correctness != -1
    attempted_correctness = correctness[attempted].tolist()

    for i, truth_method in enumerate(truth_methods):
        truth_values = np.array(output_dict[f'truth_method_{i}']['truth_values'])
        truth_values[np.isnan(truth_values)] = 0
        truth_method.normalizer.calibrate(generation_performance_scores=attempted_correctness,
                                          truth_values=truth_values[attempted])
    return output_dict
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
from .correctness_evaluator import CorrectnessEvaluator
|
|
2
|
+
from .rouge import ROUGE
|
|
3
|
+
from .bleu import BLEU
|
|
4
|
+
from .model_judge import ModelJudge
|
|
5
|
+
from .eval_truth_method import evaluate_truth_method, get_metric_scores
|
|
6
|
+
|
|
7
|
+
# Public names exported by the TruthTorchLM.evaluators package.
__all__ = ['CorrectnessEvaluator', 'ROUGE', 'BLEU', 'evaluate_truth_method', 'ModelJudge', 'get_metric_scores']
|
|
8
|
+
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
|
|
2
|
+
from .correctness_evaluator import CorrectnessEvaluator
|
|
3
|
+
import evaluate
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class BLEU(CorrectnessEvaluator):
    """Correctness evaluator based on BLEU against the reference answers.

    An answer is judged correct (1) when its BLEU score against at least one
    ground truth is strictly greater than ``threshold``; otherwise 0.
    """

    def __init__(self, threshold: float = 0.5):
        """
        Args:
            threshold: BLEU score that must be exceeded for an answer to count as correct.
        """
        super().__init__()
        self.threshold = threshold
        self.bleu = evaluate.load('bleu')

    def __call__(self, question_text: str, generated_text: str, ground_truths: list[str], seed: int = None) -> int:
        """Return 1 if ``generated_text`` matches any ground truth above the threshold, else 0.

        ``question_text`` and ``seed`` are accepted for interface compatibility and ignored.
        """
        # The BLEU brevity penalty divides by the prediction/reference length,
        # so an empty prediction or reference would raise ZeroDivisionError.
        if not generated_text or not generated_text.strip():
            return 0
        for ground_truth in ground_truths:
            if not ground_truth or not ground_truth.strip():
                continue
            bleu_results = self.bleu.compute(predictions=[generated_text], references=[ground_truth])
            if bleu_results['bleu'] > self.threshold:
                return 1
        return 0

    def __str__(self):
        return f"BLEU with threshold {self.threshold}"
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
|
|
3
|
+
class CorrectnessEvaluator(ABC):
    """Abstract interface for grading a generated answer against ground truths.

    Concrete evaluators must implement ``__call__`` (the verdict) and
    ``__str__`` (a human-readable description); the class cannot be
    instantiated until both are provided.
    """

    def __init__(self):
        """No shared state; subclasses invoke this through ``super().__init__()``."""

    @abstractmethod
    def __call__(self, question_text: str, generated_text: str, ground_truth_text: list[str], seed: int = None) -> int:
        """Grade ``generated_text`` for ``question_text`` against ``ground_truth_text``.

        The evaluators in this package return 1 (correct), 0 (incorrect) and,
        for the model judge, -1 (not attempted).
        """
        raise NotImplementedError("Subclasses must implement this method")

    @abstractmethod
    def __str__(self):
        """Describe the evaluator and its configuration."""
        raise NotImplementedError("Subclasses must implement this method")
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
|
|
2
|
+
from typing import Union
|
|
3
|
+
from TruthTorchLM.truth_methods import TruthMethod
|
|
4
|
+
from .correctness_evaluator import CorrectnessEvaluator
|
|
5
|
+
from .rouge import ROUGE
|
|
6
|
+
from TruthTorchLM.availability import AVAILABLE_EVALUATION_METRICS
|
|
7
|
+
from TruthTorchLM.templates import DEFAULT_SYSTEM_BENCHMARK_PROMPT, DEFAULT_USER_PROMPT
|
|
8
|
+
from TruthTorchLM.utils.dataset_utils import get_dataset
|
|
9
|
+
from TruthTorchLM.utils.eval_utils import metric_score, run_over_dataset
|
|
10
|
+
import wandb
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def evaluate_truth_method(dataset: Union[str, list], model: Union[str, PreTrainedModel], truth_methods: list[TruthMethod],
                          tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None, eval_metrics: list[str] = None,
                          correctness_evaluator: CorrectnessEvaluator = ROUGE(0.7), size_of_data=1.0,
                          previous_context: list = None,
                          user_prompt: str = DEFAULT_USER_PROMPT, seed: int = 0, return_method_details: bool = False,
                          wandb_run=None, wandb_push_method_details: bool = False,
                          batch_generation=True, add_generation_prompt=True, continue_final_message=False, split='test', **kwargs):
    """Run the truth methods over a dataset split and score them with the requested metrics.

    Args:
        dataset: dataset name or an already-loaded list of samples (passed to ``get_dataset``).
        model: API model name or local Hugging Face model.
        truth_methods: truth methods to evaluate.
        tokenizer: tokenizer for a local model.
        eval_metrics: metric names, each of which must be in ``AVAILABLE_EVALUATION_METRICS``;
            defaults to ``['auroc']``.
        correctness_evaluator: grades each generation (1 correct, 0 incorrect, -1 not attempted).
        size_of_data: fraction/size of the dataset to use (forwarded to ``get_dataset``).
        previous_context: chat messages prepended to every question; defaults to a single
            system message containing ``DEFAULT_SYSTEM_BENCHMARK_PROMPT``.
        wandb_run: optional W&B run; when given, accuracy, per-method scores and one bar
            chart per metric are logged to it.
        user_prompt, seed, return_method_details, wandb_push_method_details, batch_generation,
        add_generation_prompt, continue_final_message, **kwargs: forwarded to ``run_over_dataset``.
        split: dataset split to load.

    Returns:
        ``{'eval_list': [...one metric dict per truth method...], 'output_dict': output_dict}``.

    Raises:
        ValueError: if an entry of ``eval_metrics`` is not an available metric.
    """
    if eval_metrics is None:
        eval_metrics = ['auroc']
    if previous_context is None:
        previous_context = [{'role': 'system', 'content': DEFAULT_SYSTEM_BENCHMARK_PROMPT}]

    # Validate cheaply before loading the dataset and running the models.
    for eval_metric in eval_metrics:
        if eval_metric not in AVAILABLE_EVALUATION_METRICS:
            raise ValueError(f"Evaluation metric {eval_metric} is not available. Available evaluation metrics are: {AVAILABLE_EVALUATION_METRICS}")

    dataset = get_dataset(dataset, size_of_data=size_of_data, seed=seed, split=split)

    output_dict = run_over_dataset(dataset, model, truth_methods, tokenizer=tokenizer, correctness_evaluator=correctness_evaluator,
                                   previous_context=previous_context, user_prompt=user_prompt, seed=seed,
                                   return_method_details=return_method_details,
                                   wandb_run=wandb_run, wandb_push_method_details=wandb_push_method_details,
                                   batch_generation=batch_generation, add_generation_prompt=add_generation_prompt,
                                   continue_final_message=continue_final_message, **kwargs)

    eval_list = get_metric_scores(output_dict=output_dict, eval_metrics=eval_metrics, seed=seed)

    if wandb_run:
        _log_evaluation_to_wandb(wandb_run, output_dict, eval_list, truth_methods)

    return {'eval_list': eval_list, 'output_dict': output_dict}


def _log_evaluation_to_wandb(wandb_run, output_dict: dict, eval_list: list, truth_methods: list):
    """Log model accuracy, per-method metric scores and a bar chart per metric to ``wandb_run``."""
    correctness = output_dict['generation_correctness']
    if len(correctness) > 0:
        # -1 marks "not attempted"; drop it from the numerator so it does not
        # subtract from the number of correct answers.
        accuracy = sum(c for c in correctness if c != -1) / len(correctness)
        wandb_run.log({'model_accuracy': accuracy})

    if not eval_list:
        return

    method_names = [truth_method.__class__.__name__ for truth_method in truth_methods]
    for key in eval_list[0]:
        scores = []
        for i, cur_eval_dict in enumerate(eval_list):
            score = cur_eval_dict[key]
            scores.append(score)
            wandb_run.log({f'{key}_of_method_{i}_{method_names[i]}': score})

        data = [[method, score] for method, score in zip(method_names, scores)]
        table = wandb.Table(data=data, columns=["methods", "scores"])
        # Log through the caller's run rather than the global wandb state.
        wandb_run.log({f"{key}": wandb.plot.bar(table, "methods", "scores",
                                                title=f"{key} Scores of Truth Methods")})
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def get_metric_scores(output_dict: dict, eval_metrics: list[str], seed: int = 0):
    """Score every truth method recorded in ``output_dict`` with ``eval_metrics``.

    ``output_dict`` must hold ``'truth_methods'``, ``'generation_correctness'`` and,
    for each method index ``k``, an entry ``f'truth_method_{k}'`` carrying
    ``'truth_values'`` and ``'normalized_truth_values'``.

    Returns:
        A list with one ``metric_score`` result per truth method, in method order.
    """
    method_count = len(output_dict['truth_methods'])
    return [
        metric_score(
            eval_metrics,
            output_dict['generation_correctness'],
            output_dict[f'truth_method_{k}']['truth_values'],
            output_dict[f'truth_method_{k}']['normalized_truth_values'],
            seed=seed,
        )
        for k in range(method_count)
    ]
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from .correctness_evaluator import CorrectnessEvaluator
|
|
2
|
+
import evaluate
|
|
3
|
+
from typing import Union
|
|
4
|
+
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
|
|
5
|
+
from litellm import completion
|
|
6
|
+
import random
|
|
7
|
+
import torch
|
|
8
|
+
from TruthTorchLM.templates import DEFAULT_JUDGE_PROMPT, DEFAULT_JUDGE_SYSTEM_PROMPT
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ModelJudge(CorrectnessEvaluator):
    """Correctness evaluator that asks a judge LLM to grade an answer.

    The judge is either an API model name (queried through litellm ``completion``)
    or a local Hugging Face model with its tokenizer. The judge's reply is mapped
    to 1 (correct), 0 (incorrect) or -1 (not attempted / unrecognised output).
    """

    def __init__(self, model: Union[PreTrainedModel, str], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
                 prompt: str = DEFAULT_JUDGE_PROMPT, system_prompt: str = DEFAULT_JUDGE_SYSTEM_PROMPT, num_retries: int = 1) -> None:
        """
        Args:
            model: API model name or local Hugging Face model acting as judge.
            tokenizer: tokenizer for a local judge model (unused for API models).
            prompt: user prompt template; formatted with ``question``, ``ground_truths`` and ``answer``.
            system_prompt: system message sent to the judge.
            num_retries: retries passed to litellm ``completion`` for API judges.
        """
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.prompt = prompt
        self.system_prompt = system_prompt
        self.num_retries = num_retries

    def __call__(self, question_text: str, generated_text: str, ground_truths: list[str], seed: int = None) -> int:
        """Ask the judge whether ``generated_text`` answers ``question_text`` correctly.

        Args:
            question_text: the question that was asked.
            generated_text: the answer being graded.
            ground_truths: reference answers, joined with ``', '`` into the prompt.
            seed: seed for the judge call; a random one is drawn when ``None``.

        Returns:
            0 if the judge output contains "incorrect", else 1 if it contains "correct",
            else -1 if it contains "not_attempted" or is unrecognised (case-insensitive).
        """
        if seed is None:
            seed = random.randint(0, 1000000)

        chat = [{"role": "system", "content": self.system_prompt},
                {"role": "user", "content": self.prompt.format(question=question_text, ground_truths=', '.join(ground_truths), answer=generated_text)}]
        if isinstance(self.model, str):
            response = completion(
                model=self.model,
                messages=chat,
                seed=seed,
                num_retries=self.num_retries
            )
            judge_output = response.choices[0].message['content']
        else:
            torch.manual_seed(seed)
            random.seed(seed)
            # Open an assistant turn so the judge replies instead of continuing the user message
            # (same default as the library's own local generation path).
            text = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
            input_ids = self.tokenizer.encode(text, return_tensors="pt").to(self.model.device)
            model_output = self.model.generate(input_ids)
            tokens = model_output[0][len(input_ids[0]):]
            judge_output = self.tokenizer.decode(tokens, skip_special_tokens=False)

        verdict = judge_output.lower()
        # "incorrect" must be tested first because it contains "correct".
        if 'incorrect' in verdict:
            return 0
        elif 'correct' in verdict:
            return 1
        elif "not_attempted" in verdict:
            return -1
        else:
            print("The output of the judge model is not in the expected format. Not attempted will be returned.")
            return -1

    def __str__(self):
        if isinstance(self.model, str):
            model_name = self.model
        else:
            # Avoid printing a whole architecture; prefer the checkpoint name if present.
            model_name = getattr(self.model, 'name_or_path', type(self.model).__name__)
        return f"ModelJudge with model {model_name}"
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from .correctness_evaluator import CorrectnessEvaluator
|
|
2
|
+
import evaluate
|
|
3
|
+
|
|
4
|
+
class ROUGE(CorrectnessEvaluator):
    """Correctness evaluator based on ROUGE overlap with the reference answers.

    An answer is judged correct (1) when its ``rouge_type`` score against at least
    one ground truth is strictly greater than ``threshold``; otherwise 0.
    """

    def __init__(self, threshold: float = 0.5, rouge_type: str = 'rougeL'):
        """
        Args:
            threshold: score that must be exceeded for an answer to count as correct.
            rouge_type: ROUGE variant to compare, e.g. ``'rouge1'``, ``'rouge2'``, ``'rougeL'``.
        """
        super().__init__()
        self.threshold = threshold
        self.rouge = evaluate.load('rouge')
        self.rouge_type = rouge_type

    def __call__(self, question_text: str, generated_text: str, ground_truths: list[str], seed: int = None) -> int:
        """Return 1 if ``generated_text`` matches any ground truth above the threshold, else 0.

        ``question_text`` and ``seed`` are accepted for interface compatibility and ignored.
        """
        for ground_truth in ground_truths:
            # Only the compared variant is needed; skip computing the others.
            rouge_results = self.rouge.compute(predictions=[generated_text], references=[ground_truth],
                                               rouge_types=[self.rouge_type])
            if rouge_results[self.rouge_type] > self.threshold:
                return 1
        return 0

    def __str__(self):
        return f"ROUGE with threshold {self.threshold} and type {self.rouge_type}"
|
|
@@ -0,0 +1,389 @@
|
|
|
1
|
+
|
|
2
|
+
import copy
|
|
3
|
+
import torch
|
|
4
|
+
import random
|
|
5
|
+
from typing import Union
|
|
6
|
+
from litellm import completion
|
|
7
|
+
|
|
8
|
+
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
|
|
9
|
+
|
|
10
|
+
#from .truth_methods.truth_method import TruthMethod
|
|
11
|
+
from TruthTorchLM.availability import AVAILABLE_API_MODELS, PROB_AVAILABLE_API_MODELS
|
|
12
|
+
from TruthTorchLM.utils.common_utils import generate, fix_tokenizer_chat
|
|
13
|
+
|
|
14
|
+
import time
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def generate_with_truth_value(model: Union[PreTrainedModel, str], messages: list, question_context: str = None, truth_methods: list = None,
                              tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
                              generation_seed=None, batch_generation=True, add_generation_prompt=True, continue_final_message=False, **kwargs) -> dict:
    """Generate a reply to ``messages`` and score it with every truth method.

    Dispatches to the API path when ``model`` is a string, otherwise to the local
    Hugging Face path. ``tokenizer``, ``batch_generation``, ``add_generation_prompt``
    and ``continue_final_message`` are only used by the local path.

    Args:
        truth_methods: truth methods to run; ``None`` means none.

    Returns:
        The truth dict produced by the selected backend.
    """
    if truth_methods is None:
        truth_methods = []
    if isinstance(model, str):
        return generate_with_truth_value_api(model=model, messages=messages, question_context=question_context,
                                             truth_methods=truth_methods, generation_seed=generation_seed, **kwargs)
    return generate_with_truth_value_hf_local(model=model, messages=messages, question_context=question_context, truth_methods=truth_methods,
                                              tokenizer=tokenizer, generation_seed=generation_seed, batch_generation=batch_generation,
                                              add_generation_prompt=add_generation_prompt, continue_final_message=continue_final_message, **kwargs)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
#TODO: remove number of generations from kwargs if exists
|
|
29
|
+
def generate_with_truth_value_hf_local(model: PreTrainedModel, messages: list, question_context: str = None, truth_methods: list = None,
                                       tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None, generation_seed=None,
                                       batch_generation=True, add_generation_prompt=True, continue_final_message=False, **kwargs) -> dict:
    """Generate a reply with a local Hugging Face model and score it with each truth method.

    Args:
        model: local model used for generation.
        messages: chat history as ``{'role': ..., 'content': ...}`` dicts.
        question_context: question text handed to the truth methods; defaults to the
            content of the last ``'user'`` message, or ``''`` if there is none.
        truth_methods: truth methods to run; ``None`` means none.
        tokenizer: tokenizer for ``model``; first passed through ``fix_tokenizer_chat``.
        generation_seed: seed forwarded to the sampler and to every truth method.
        batch_generation: use the batched sampler instead of the sequential one.
        add_generation_prompt, continue_final_message: forwarded to ``apply_chat_template``.
        **kwargs: forwarded to ``generate``, the sampler and every truth method.

    Returns:
        dict with ``generated_text`` (special tokens skipped), ``normalized_truth_values``,
        ``unnormalized_truth_values``, ``method_specific_outputs`` (one per truth method,
        in order), ``all_ids`` (moved to CPU) and ``generated_tokens``.
    """
    if truth_methods is None:
        truth_methods = []

    tokenizer, messages = fix_tokenizer_chat(tokenizer, messages)
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=add_generation_prompt,
                                         continue_final_message=continue_final_message)
    if question_context is None:
        # Fall back to the most recent user turn.
        question_context = next((message['content'] for message in reversed(messages) if message['role'] == 'user'), '')

    generated_output = generate(text, model, tokenizer, **kwargs)
    generated_text_return = generated_output['generated_text_skip_specials']
    generated_text = generated_output['generated_text']
    tokens = generated_output['tokens']
    model_output = generated_output['all_ids']

    # Extra sampled generations, shared by all truth methods that need them.
    number_of_generations, return_text, return_logits, return_logprobs, return_attentions, return_activations = get_sampling_properties(truth_methods)
    sampled_gen_dict = sample_generations_hf_local(model, text, tokenizer, generation_seed, number_of_generations=number_of_generations,
                                                   return_text=return_text, return_logits=return_logits, return_logprobs=return_logprobs,
                                                   return_attentions=return_attentions, return_activations=return_activations,
                                                   batch_generation=batch_generation, **kwargs)

    normalized_truth_values = []
    unnormalized_truth_values = []
    method_spec_outputs = []
    for truth_method in truth_methods:
        truth_values = truth_method(model=model, input_text=text, generated_text=generated_text, question_context=question_context,
                                    all_ids=model_output, tokenizer=tokenizer, generation_seed=generation_seed,
                                    sampled_generations_dict=sampled_gen_dict, **kwargs)
        normalized_truth_values.append(truth_values['normalized_truth_value'])
        unnormalized_truth_values.append(truth_values['truth_value'])
        method_spec_outputs.append(truth_values)

    return {'generated_text': generated_text_return, 'normalized_truth_values': normalized_truth_values,
            'unnormalized_truth_values': unnormalized_truth_values, 'method_specific_outputs': method_spec_outputs,
            'all_ids': model_output.cpu(), 'generated_tokens': tokens}
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
#for api-based models, we should write a wrapper function to handle exceptions during the api call
|
|
77
|
+
def generate_with_truth_value_api(model: str, messages: list, question_context: str = None, truth_methods: list = None,
                                  generation_seed=None, **kwargs) -> dict:
    """Generate a reply with an API model (via litellm) and score it with each truth method.

    Args:
        model: model name; must be in ``AVAILABLE_API_MODELS`` (and in
            ``PROB_AVAILABLE_API_MODELS`` if any truth method requires log-probs).
        messages: chat history as ``{'role': ..., 'content': ...}`` dicts.
        question_context: question text handed to the truth methods; defaults to the
            content of the last ``'user'`` message, or ``''`` if there is none.
        truth_methods: truth methods to run; ``None`` means none.
        generation_seed: seeds Python's ``random`` so the drawn API seed is reproducible.
        **kwargs: forwarded to ``completion``, the sampler and every truth method. A ``seed``
            entry is used for the API call; one is drawn at random if absent or ``None``.

    Returns:
        dict with ``generated_text``, ``normalized_truth_values``, ``unnormalized_truth_values``
        and ``method_specific_outputs`` (one per truth method, in order).

    Raises:
        ValueError: if the model is unsupported, or unsupported for log-prob requiring methods.
    """
    if truth_methods is None:
        truth_methods = []
    if generation_seed is not None:
        random.seed(generation_seed)

    if isinstance(model, str) and model not in AVAILABLE_API_MODELS:
        raise ValueError(f"model {model} is not supported.")

    requires_logprobs = any(truth_method.REQUIRES_LOGPROBS for truth_method in truth_methods)
    if requires_logprobs and model not in PROB_AVAILABLE_API_MODELS:
        raise ValueError(f"model {model} is not supported for probability requiring truth methods.")

    if question_context is None:
        # Fall back to the most recent user turn.
        question_context = next((message['content'] for message in reversed(messages) if message['role'] == 'user'), '')

    # A random seed is generated if none is specified, and reused downstream via kwargs.
    seed = kwargs.pop('seed', None)
    if seed is None:
        seed = random.randint(0, 1000000)
    kwargs['seed'] = seed

    response = completion(
        model=model,
        messages=messages,
        logprobs=requires_logprobs,
        **kwargs
    )
    generated_text = response.choices[0].message['content']

    logprobs = None
    generated_tokens = None
    if requires_logprobs:
        token_infos = response.choices[0].logprobs['content']
        logprobs = [token['logprob'] for token in token_infos]
        generated_tokens = [token['token'] for token in token_infos]

    # Extra sampled generations, shared by all truth methods that need them.
    number_of_generations, return_text, return_logits, return_logprobs, return_attentions, return_activations = get_sampling_properties(truth_methods)
    sampled_gen_dict = sample_generations_api(model, messages, generation_seed, number_of_generations=number_of_generations,
                                              return_text=return_text, return_logits=return_logits, return_logprobs=return_logprobs,
                                              return_attentions=return_attentions, return_activations=return_activations, **kwargs)

    normalized_truth_values = []
    unnormalized_truth_values = []
    method_spec_outputs = []
    for truth_method in truth_methods:
        truth_values = truth_method(model=model, messages=messages, generated_text=generated_text, question_context=question_context,
                                    generation_seed=generation_seed, sampled_generations_dict=sampled_gen_dict,
                                    logprobs=logprobs, generated_tokens=generated_tokens, **kwargs)
        normalized_truth_values.append(truth_values['normalized_truth_value'])
        unnormalized_truth_values.append(truth_values['truth_value'])
        method_spec_outputs.append(truth_values)

    return {'generated_text': generated_text, 'normalized_truth_values': normalized_truth_values,
            'unnormalized_truth_values': unnormalized_truth_values, 'method_specific_outputs': method_spec_outputs}
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def get_sampling_properties(truth_methods: list):
    """Merge the sampled-generation requirements of several truth methods.

    Returns a 6-tuple ``(number_of_generations, return_text, return_logits,
    return_logprobs, return_attentions, return_activations)`` where the first item
    is the largest ``number_of_generations`` declared by any method (0 when none
    declares one) and each flag is True iff at least one method's matching
    ``REQUIRES_SAMPLED_*`` attribute is truthy.
    """
    flag_names = ('REQUIRES_SAMPLED_TEXT', 'REQUIRES_SAMPLED_LOGITS', 'REQUIRES_SAMPLED_LOGPROBS',
                  'REQUIRES_SAMPLED_ATTENTIONS', 'REQUIRES_SAMPLED_ACTIVATIONS')
    most_generations = 0
    per_method_flags = []
    for method in truth_methods:
        if hasattr(method, 'number_of_generations') and method.number_of_generations > most_generations:
            most_generations = method.number_of_generations
        per_method_flags.append(tuple(getattr(method, name) for name in flag_names))

    needs_text, needs_logits, needs_logprobs, needs_attentions, needs_activations = (
        any(flags[col] for flags in per_method_flags) for col in range(len(flag_names))
    )
    return most_generations, needs_text, needs_logits, needs_logprobs, needs_attentions, needs_activations
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def sample_generations_hf_local(model: PreTrainedModel, input_text: str, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], generation_seed: int = None,
                                number_of_generations: int = 0, return_text: bool = False, return_logits: bool = False, return_logprobs: bool = False,
                                return_attentions: bool = False, return_activations: bool = False, batch_generation=False, **kwargs):
    """Draw extra sampled generations from a local model for the truth methods.

    Args:
        model, input_text, tokenizer: local model, prompt text and its tokenizer.
        generation_seed: when given, seeds ``torch`` and ``random`` before sampling.
        number_of_generations: number of samples to draw.
        return_text, return_logits, return_logprobs, return_attentions, return_activations:
            which outputs the truth methods need.
        batch_generation: truthy selects the batched sampler, falsy the sequential one.
        **kwargs: forwarded to the selected sampler.

    Returns:
        The selected sampler's result, or ``None`` when no samples or outputs are requested.
    """
    needs_any_output = return_text or return_logprobs or return_activations or return_attentions or return_logits
    if number_of_generations == 0 or not needs_any_output:
        return None

    if generation_seed is not None:
        torch.manual_seed(generation_seed)
        random.seed(generation_seed)

    # Previously `== True` / `== False` silently returned None for other truthy/falsy values.
    sampler = sample_generations_batch_hf_local if batch_generation else sample_generations_sequential_hf_local
    return sampler(model=model, input_text=input_text, tokenizer=tokenizer, number_of_generations=number_of_generations,
                   return_text=return_text, return_logits=return_logits, return_logprobs=return_logprobs,
                   return_attentions=return_attentions, return_activations=return_activations, **kwargs)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def _first_eos_indices(generated_ids, eos_token_id):
    """Return the inclusive end index of every generated sequence.

    For each row of ``generated_ids`` (generated tokens only, prompt removed) this is
    the position of the first token equal to any EOS id, or ``gen_len - 1`` when the
    row contains no EOS token. An EOS at position 0 is reported as 0.
    ``eos_token_id`` may be a single id, a list/tuple of ids, or None (rows kept whole).
    """
    if isinstance(eos_token_id, (list, tuple)):
        eos_ids = [eos for eos in eos_token_id if eos is not None]
    elif eos_token_id is None:
        eos_ids = []
    else:
        eos_ids = [eos_token_id]

    num_rows, gen_len = generated_ids.shape
    if gen_len == 0:
        return [-1] * num_rows  # slicing with [:end + 1] then yields empty sequences

    is_eos = torch.zeros_like(generated_ids, dtype=torch.bool)
    for eos in eos_ids:
        is_eos |= generated_ids == eos
    # torch.argmax returns the first maximal index, i.e. the first EOS position in each row.
    first_eos = torch.argmax(is_eos.to(dtype=torch.int), dim=-1)
    last_index = torch.full_like(first_eos, gen_len - 1)
    return torch.where(is_eos.any(dim=-1), first_eos, last_index).tolist()


def sample_generations_batch_hf_local(model:PreTrainedModel, input_text:str, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        number_of_generations:int = 0, return_text:bool = False, return_logits:bool = False, return_logprobs:bool = False, return_attentions:bool = False, return_activations:bool = False, return_model_output:bool = True, **kwargs):
    """Sample ``number_of_generations`` completions with a single batched ``model.generate`` call.

    Sampling is done with ``do_sample=True`` and ``num_return_sequences=number_of_generations``.
    Every per-generation output is truncated right after the first EOS token (inclusive),
    or kept whole when the generation contains no EOS.

    Args:
        model: Hugging Face model providing ``generate``, ``device`` and ``config.eos_token_id``.
        input_text: prompt text.
        tokenizer: tokenizer matching ``model``.
        number_of_generations: number of sequences to sample.
        return_text: collect decoded texts and their token ids.
        return_logits: collect per-token logits.
        return_logprobs: collect log-probabilities of the sampled tokens.
        return_attentions: collect attention tensors per generated step and layer.
        return_activations: collect hidden states per generated step and layer.
        return_model_output: if False, ``"model_outputs"`` is None.
        **kwargs: forwarded to ``model.generate``. The flags this function sets itself are
            dropped. ``eos_token_id`` and ``pad_token_id`` are honored when given; otherwise
            they default to ``model.config.eos_token_id`` and the (first) EOS id.

    Returns:
        None when ``number_of_generations`` is 0 or nothing is requested. Otherwise a dict:
        ``"generated_texts"`` (list of str), ``"tokens"`` (list of token-id lists),
        ``"logprobs"`` (list of float lists), ``"logits"`` (list of tensors, one row per kept token),
        ``"activations"`` / ``"attentions"`` (nested lists indexed [generation][step][layer]),
        and ``"model_outputs"`` (full CPU ``sequences`` tensor including the prompt, or None).
        Outputs that were not requested are empty lists.
    """
    if number_of_generations == 0 or not (return_text or return_logprobs or return_activations or return_attentions or return_logits):
        return None

    # Work on a private copy and drop arguments that this function passes explicitly,
    # so they cannot be given twice to generate().
    kwargs = copy.deepcopy(kwargs)
    for controlled in ('do_sample', 'num_return_sequences', 'return_dict_in_generate',
                       'output_attentions', 'output_hidden_states', 'output_logits'):
        kwargs.pop(controlled, None)

    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    prompt_len = len(inputs['input_ids'][0])

    eos_token_id = kwargs.pop("eos_token_id", None)
    if eos_token_id is None:
        eos_token_id = model.config.eos_token_id

    pad_token_id = kwargs.pop("pad_token_id", None)
    if pad_token_id is None:
        # Batched sequences of different lengths are padded; default the pad id to the (first) EOS id.
        pad_token_id = eos_token_id[0] if isinstance(eos_token_id, (list, tuple)) else eos_token_id

    generated_texts = []
    logits_list = []
    logprobs = []
    attentions_list = []
    activations_list = []
    tokens = []

    with torch.no_grad():
        model_output = model.generate(**inputs, num_return_sequences=number_of_generations, do_sample=True, return_dict_in_generate=True,
                                      output_attentions=return_attentions, output_hidden_states=return_activations,
                                      output_logits=(return_logits or return_logprobs),
                                      eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs)

    model_output.past_key_values = None  # not needed below; dropped presumably to release memory
    model_output.sequences = model_output.sequences.cpu()
    generated_ids = model_output.sequences[:, prompt_len:]  # generated tokens only, one row per generation
    # Inclusive index of the last token to keep in each generation.
    indices = _first_eos_indices(generated_ids, eos_token_id)

    if return_text:
        tokens = [seq[:end + 1].tolist() for seq, end in zip(generated_ids, indices)]
        generated_texts = tokenizer.batch_decode(tokens, skip_special_tokens=True)

    if return_logprobs or return_logits:
        # One tensor per generation step stacked, then permuted so axis 0 is the generation.
        logits = torch.stack(model_output.logits).cpu().permute(1, 0, 2)
        model_output.logits = None
        if return_logprobs:
            token_logprobs = torch.log_softmax(logits, dim=-1)
            # Pick the log-probability of the token that was actually sampled at each step.
            token_logprobs = torch.gather(token_logprobs, dim=-1, index=generated_ids.unsqueeze(-1)).squeeze(-1).tolist()
            logprobs = [row[:end + 1] for row, end in zip(token_logprobs, indices)]
        if return_logits:
            logits_list = [row[:end + 1] for row, end in zip(logits, indices)]

    if return_activations:
        num_layers = len(model_output.hidden_states[0])
        for i, end in enumerate(indices):  # generation id
            activations_list.append([[model_output.hidden_states[j][k][i].cpu() for k in range(num_layers)]
                                     for j in range(end + 1)])  # [token][layer]
        model_output.hidden_states = None

    if return_attentions:
        num_layers = len(model_output.attentions[0])
        for i, end in enumerate(indices):  # generation id
            attentions_list.append([[model_output.attentions[j][k][i].cpu() for k in range(num_layers)]
                                    for j in range(end + 1)])  # [token][layer]
        model_output.attentions = None

    if not return_model_output:
        model_output.sequences = None

    return {"generated_texts": generated_texts, "logprobs": logprobs, "activations": activations_list, "logits":logits_list, "attentions":attentions_list, "model_outputs": model_output.sequences, "tokens":tokens}
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def sample_generations_sequential_hf_local(model:PreTrainedModel, input_text:str, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        number_of_generations:int = 0, do_sample:bool=True, return_text:bool = False, return_logits:bool = False, return_logprobs:bool = False, return_attentions:bool = False, return_activations:bool = False, return_model_output:bool = True, **kwargs):
    """Draw ``number_of_generations`` samples by calling ``model.generate`` once per sample.

    Every call uses ``num_return_sequences=1`` and ``return_dict_in_generate=True``.
    ``eos_token_id`` is taken from ``kwargs`` when present, else from ``model.config``.
    The remaining ``kwargs`` (minus the generate flags this function sets itself) are
    forwarded to ``model.generate``.

    Returns:
        A dict with keys "generated_texts", "logprobs", "activations", "logits",
        "attentions", "model_outputs" and "tokens". Each value is a list with one entry
        per generation for the requested outputs, and an empty list otherwise.
        "model_outputs" holds the CPU ``sequences`` tensor of each call (prompt included).
    """
    gen_kwargs = copy.deepcopy(kwargs)
    # These generate() arguments are always supplied explicitly below.
    for controlled in ('do_sample', 'num_return_sequences', 'return_dict_in_generate',
                       'output_attentions', 'output_hidden_states', 'output_logits'):
        gen_kwargs.pop(controlled, None)

    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    prompt_len = len(inputs['input_ids'][0])

    eos_token_id = gen_kwargs.pop("eos_token_id", None)
    if eos_token_id is None:
        eos_token_id = model.config.eos_token_id

    results = {"generated_texts": [], "logprobs": [], "activations": [], "logits": [], "attentions": [], "model_outputs": [], "tokens": []}
    need_scores = return_logits or return_logprobs

    for _ in range(number_of_generations):
        with torch.no_grad():
            model_output = model.generate(**inputs, num_return_sequences=1, do_sample=do_sample, return_dict_in_generate=True,
                                          output_attentions=return_attentions, output_hidden_states=return_activations,
                                          output_logits=need_scores, eos_token_id=eos_token_id, **gen_kwargs)
        model_output.past_key_values = None
        sequence = model_output.sequences.cpu()
        model_output.sequences = sequence
        if return_model_output:
            results["model_outputs"].append(sequence)

        new_ids = sequence[0][prompt_len:]  # generated part of the single returned sequence

        if return_text:
            results["generated_texts"].append(tokenizer.decode(new_ids, skip_special_tokens=True))
            results["tokens"].append(new_ids.tolist())

        if need_scores:
            step_logits = torch.cat(model_output.logits).cpu()  # one row per generated step
            model_output.logits = None
            if return_logprobs:
                log_dist = torch.log_softmax(step_logits, dim=-1)
                chosen = torch.gather(log_dist, dim=1, index=new_ids.view(-1, 1))  # log-prob of each sampled token
                results["logprobs"].append(chosen.view(-1).tolist())
            if return_logits:
                results["logits"].append(step_logits)

        if return_activations:
            # [step][layer] -> hidden state of the single sequence, moved to CPU
            results["activations"].append([[layer[0].cpu() for layer in step] for step in model_output.hidden_states])
            model_output.hidden_states = None

        if return_attentions:
            # [step][layer] -> attention tensor of the single sequence, moved to CPU
            results["attentions"].append([[layer[0].cpu() for layer in step] for step in model_output.attentions])
            model_output.attentions = None

    return results
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
def sample_generations_api(model:str, messages:list, generation_seed:int=None,
        number_of_generations:int = 0, return_text:bool = False, return_logits:bool = False, return_logprobs:bool = False, return_attentions:bool = False, return_activations:bool = False, **kwargs):
    """Sample ``number_of_generations`` chat completions for ``messages`` via ``completion``.

    Only text and token log-probabilities are collected on this path. ``return_logits``,
    ``return_attentions`` and ``return_activations`` are accepted but not used.
    Every request gets its own ``seed`` drawn from the ``random`` module, which is
    seeded with ``generation_seed`` when given. Any ``seed`` or ``logprobs`` entry in
    ``kwargs`` is replaced; other ``kwargs`` are forwarded to ``completion``.

    Returns:
        None when ``number_of_generations`` is 0 or neither text nor logprobs is
        requested; otherwise a dict with "generated_texts", "logprobs" and "tokens".
    """
    if number_of_generations == 0 or not (return_text or return_logprobs):
        return None

    if generation_seed is not None:
        random.seed(generation_seed)

    request_kwargs = copy.deepcopy(kwargs)
    # Both are set explicitly for every request below.
    request_kwargs.pop('logprobs', None)
    request_kwargs.pop('seed', None)

    texts, token_logprobs, token_strings = [], [], []
    for _ in range(number_of_generations):
        request_kwargs['seed'] = random.randint(0, 1000000)
        response = completion(model=model, messages=messages, logprobs=return_logprobs, **request_kwargs)
        choice = response.choices[0]
        if return_text:
            texts.append(choice.message['content'])
        if return_logprobs:
            content = choice.logprobs['content']
            token_logprobs.append([entry['logprob'] for entry in content])
            token_strings.append([entry['token'] for entry in content])

    return {"generated_texts": texts, "logprobs": token_logprobs, "tokens": token_strings}
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
from .decomposition_method import FactualDecompositionMethod
|
|
2
|
+
from .unstructured_decomposition_api import UnstructuredDecompositionAPI
|
|
3
|
+
from .unstructured_decomposition_local import UnstructuredDecompositionLocal
|
|
4
|
+
from .structured_decomposition_api import StructuredDecompositionAPI
|
|
5
|
+
from .structured_decomposition_local import StructuredDecompositionLocal
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# Names re-exported by `from <this package> import *`.
# NOTE(review): FactualDecompositionMethod is imported above but not listed here,
# so a star import does not expose it — confirm this is intentional.
__all__ = [
    'UnstructuredDecompositionAPI',
    'UnstructuredDecompositionLocal',
    'StructuredDecompositionAPI',
    'StructuredDecompositionLocal',
]
|