TruthTorchLM 0.0.0 (truthtorchlm-0.0.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. TruthTorchLM/__init__.py +16 -0
  2. TruthTorchLM/availability.py +14 -0
  3. TruthTorchLM/calibration.py +36 -0
  4. TruthTorchLM/evaluators/__init__.py +8 -0
  5. TruthTorchLM/evaluators/bleu.py +20 -0
  6. TruthTorchLM/evaluators/correctness_evaluator.py +14 -0
  7. TruthTorchLM/evaluators/eval_truth_method.py +59 -0
  8. TruthTorchLM/evaluators/model_judge.py +61 -0
  9. TruthTorchLM/evaluators/rouge.py +19 -0
  10. TruthTorchLM/generation.py +389 -0
  11. TruthTorchLM/long_form_generation/__init__.py +5 -0
  12. TruthTorchLM/long_form_generation/decomposition_methods/__init__.py +8 -0
  13. TruthTorchLM/long_form_generation/decomposition_methods/decomposition_method.py +27 -0
  14. TruthTorchLM/long_form_generation/decomposition_methods/structured_decomposition_api.py +50 -0
  15. TruthTorchLM/long_form_generation/decomposition_methods/structured_decomposition_local.py +43 -0
  16. TruthTorchLM/long_form_generation/decomposition_methods/unstructured_decomposition_api.py +50 -0
  17. TruthTorchLM/long_form_generation/decomposition_methods/unstructured_decomposition_local.py +65 -0
  18. TruthTorchLM/long_form_generation/evaluators/__init__.py +4 -0
  19. TruthTorchLM/long_form_generation/evaluators/eval_claim.py +223 -0
  20. TruthTorchLM/long_form_generation/evaluators/long_gen_eval.py +158 -0
  21. TruthTorchLM/long_form_generation/generation.py +167 -0
  22. TruthTorchLM/long_form_generation/statement_check_methods/__init__.py +7 -0
  23. TruthTorchLM/long_form_generation/statement_check_methods/answer_statement_entailment.py +219 -0
  24. TruthTorchLM/long_form_generation/statement_check_methods/question_answer_generation.py +354 -0
  25. TruthTorchLM/long_form_generation/statement_check_methods/question_generation.py +293 -0
  26. TruthTorchLM/long_form_generation/statement_check_methods/statement_check_method.py +46 -0
  27. TruthTorchLM/long_form_generation/utils/__init__.py +3 -0
  28. TruthTorchLM/long_form_generation/utils/dataset_utils.py +90 -0
  29. TruthTorchLM/long_form_generation/utils/eval_utils.py +188 -0
  30. TruthTorchLM/long_form_generation/utils/safe_utils.py +231 -0
  31. TruthTorchLM/normalizers/__init__.py +4 -0
  32. TruthTorchLM/normalizers/normalizer.py +36 -0
  33. TruthTorchLM/normalizers/sigmoid_normalizer.py +34 -0
  34. TruthTorchLM/scoring_methods/__init__.py +5 -0
  35. TruthTorchLM/scoring_methods/length_normalized_scoring.py +12 -0
  36. TruthTorchLM/scoring_methods/log_prob_scoring.py +11 -0
  37. TruthTorchLM/scoring_methods/scoring_method.py +19 -0
  38. TruthTorchLM/templates.py +169 -0
  39. TruthTorchLM/truth_methods/__init__.py +31 -0
  40. TruthTorchLM/truth_methods/attention_score.py +52 -0
  41. TruthTorchLM/truth_methods/confidence.py +59 -0
  42. TruthTorchLM/truth_methods/cross_examination.py +164 -0
  43. TruthTorchLM/truth_methods/eccentricity_confidence.py +74 -0
  44. TruthTorchLM/truth_methods/eccentricity_uncertainty.py +69 -0
  45. TruthTorchLM/truth_methods/entropy.py +66 -0
  46. TruthTorchLM/truth_methods/google_search_check.py +144 -0
  47. TruthTorchLM/truth_methods/inside.py +49 -0
  48. TruthTorchLM/truth_methods/kernel_language_entropy.py +81 -0
  49. TruthTorchLM/truth_methods/lars.py +479 -0
  50. TruthTorchLM/truth_methods/mars.py +196 -0
  51. TruthTorchLM/truth_methods/matrix_degree_confidence.py +78 -0
  52. TruthTorchLM/truth_methods/matrix_degree_uncertainty.py +74 -0
  53. TruthTorchLM/truth_methods/multi_llm_collab.py +535 -0
  54. TruthTorchLM/truth_methods/num_semantic_set_uncertainty.py +70 -0
  55. TruthTorchLM/truth_methods/p_true.py +71 -0
  56. TruthTorchLM/truth_methods/saplma.py +206 -0
  57. TruthTorchLM/truth_methods/self_detection.py +133 -0
  58. TruthTorchLM/truth_methods/semantic_entropy.py +93 -0
  59. TruthTorchLM/truth_methods/sentSAR.py +101 -0
  60. TruthTorchLM/truth_methods/sum_eigen_uncertainty.py +71 -0
  61. TruthTorchLM/truth_methods/tokenSAR.py +76 -0
  62. TruthTorchLM/truth_methods/truth_method.py +73 -0
  63. TruthTorchLM/truth_methods/verbalized_confidence.py +77 -0
  64. TruthTorchLM/utils/__init__.py +5 -0
  65. TruthTorchLM/utils/calibration_utils.py +64 -0
  66. TruthTorchLM/utils/common_utils.py +374 -0
  67. TruthTorchLM/utils/dataset_utils.py +127 -0
  68. TruthTorchLM/utils/eval_utils.py +280 -0
  69. TruthTorchLM/utils/google_search_utils.py +136 -0
  70. truthtorchlm-0.0.0.dist-info/LICENSE +21 -0
  71. truthtorchlm-0.0.0.dist-info/LICENSE copy +21 -0
  72. truthtorchlm-0.0.0.dist-info/METADATA +209 -0
  73. truthtorchlm-0.0.0.dist-info/RECORD +75 -0
  74. truthtorchlm-0.0.0.dist-info/WHEEL +5 -0
  75. truthtorchlm-0.0.0.dist-info/top_level.txt +1 -0
TruthTorchLM/__init__.py
@@ -0,0 +1,16 @@
+ from .truth_methods.truth_method import TruthMethod
+ from TruthTorchLM import utils ##TODO do we really need to import this?
+ from TruthTorchLM import scoring_methods
+ from TruthTorchLM import truth_methods
+ from .generation import generate_with_truth_value
+ from .calibration import calibrate_truth_method
+ from TruthTorchLM import evaluators
+ from .evaluators import evaluate_truth_method
+ from .templates import DEFAULT_USER_PROMPT, DEFAULT_SYSTEM_PROMPT ##TODO import all?
+ from .availability import AVAILABLE_DATASETS, AVAILABLE_EVALUATION_METRICS
+ from TruthTorchLM import normalizers
+
+ from TruthTorchLM import long_form_generation
+
+
+ #__all__ = ['generate_with_truth_value']
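Usage note: `__init__.py` re-exports the package's user-facing entry points. A minimal sketch of the intended flow; the truth-method class names and constructor arguments below are assumptions inferred from the file list (confidence.py, entropy.py), not confirmed API:

    import TruthTorchLM as ttlm

    # Hypothetical class names; the exact constructors are assumptions.
    truth_methods = [ttlm.truth_methods.Confidence(), ttlm.truth_methods.Entropy()]

    result = ttlm.generate_with_truth_value(
        model="gpt-4o-mini",  # must be listed in AVAILABLE_API_MODELS
        messages=[{"role": "user", "content": "What is the capital of France?"}],
        truth_methods=truth_methods,
    )
    print(result["generated_text"])
    print(result["normalized_truth_values"])  # one value per truth method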
TruthTorchLM/availability.py
@@ -0,0 +1,14 @@
+ AVAILABLE_API_MODELS = ['gpt-4o', 'gpt-4o-2024-05-13', 'gpt-4o-2024-08-06', 'chatgpt-4o-latest', 'gpt-4o-mini', 'gpt-4o-mini-2024-07-18',
+                         'gpt-4-turbo', 'gpt-4-turbo-2024-04-09', 'gpt-4-turbo-preview', 'gpt-4-0125-preview', 'gpt-4-1106-preview', 'gpt-4',
+                         'gpt-4-0613', 'gpt-4-0314', 'gpt-3.5-turbo-0125', 'gpt-3.5-turbo', 'gpt-3.5-turbo-1106', 'gpt-3.5-turbo-instruct', 'together_ai/togethercomputer/llama-2-70b']
+
+ PROB_AVAILABLE_API_MODELS = ['gpt-4o', 'gpt-4o-2024-05-13', 'gpt-4o-2024-08-06', 'chatgpt-4o-latest', 'gpt-4o-mini', 'gpt-4o-mini-2024-07-18',
+                              'gpt-4-turbo', 'gpt-4-turbo-2024-04-09', 'gpt-4-turbo-preview', 'gpt-4-0125-preview', 'gpt-4-1106-preview', 'gpt-4',
+                              'gpt-4-0613', 'gpt-4-0314', 'gpt-3.5-turbo-0125', 'gpt-3.5-turbo', 'gpt-3.5-turbo-1106', 'gpt-3.5-turbo-instruct', 'together_ai/togethercomputer/llama-2-70b']
+
+ ACTIVATION_AVAILABLE_API_MODELS = []
+
+ AVAILABLE_DATASETS = ['trivia_qa', 'gsm8k', 'natural_qa', 'pop_qa', 'simple_qa']
+ LONG_FORM_AVAILABLE_DATASETS = ['longfact_concepts', 'longfact_objects']
+
+ AVAILABLE_EVALUATION_METRICS = ['auroc', 'auprc', 'auarc', 'accuracy', 'f1', 'precision', 'recall', 'prr']
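Usage note: these constants are the validation lists used elsewhere in the package (generation.py checks model names against them, eval_truth_method.py checks metric names). A quick guard sketch mirroring those checks:

    from TruthTorchLM.availability import AVAILABLE_API_MODELS, AVAILABLE_EVALUATION_METRICS

    model = "gpt-4o-mini"
    if model not in AVAILABLE_API_MODELS:
        raise ValueError(f"model {model} is not supported.")
    assert "auroc" in AVAILABLE_EVALUATION_METRICS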
TruthTorchLM/calibration.py
@@ -0,0 +1,36 @@
+ from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
+ from typing import Union
+ from TruthTorchLM.truth_methods import TruthMethod
+ from TruthTorchLM.evaluators import CorrectnessEvaluator, ROUGE
+ from TruthTorchLM.templates import DEFAULT_SYSTEM_BENCHMARK_PROMPT, DEFAULT_USER_PROMPT
+ from TruthTorchLM.utils.dataset_utils import get_dataset
+ from TruthTorchLM.utils.eval_utils import run_over_dataset
+ import numpy as np
+
+
+ def calibrate_truth_method(dataset: Union[str, list], model: Union[str, PreTrainedModel], truth_methods: list[TruthMethod], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
+                            correctness_evaluator: CorrectnessEvaluator = ROUGE(0.7), size_of_data: float = 1.0, previous_context: list = [{'role': 'system', 'content': DEFAULT_SYSTEM_BENCHMARK_PROMPT}],
+                            user_prompt: str = DEFAULT_USER_PROMPT, seed: int = 0, wandb_run=None, return_method_details: bool = False, wandb_push_method_details: bool = False, split='train', **kwargs):
+
+     dataset = get_dataset(dataset, size_of_data=size_of_data, seed=seed, split=split)
+
+     output_dict = run_over_dataset(dataset, model, truth_methods, tokenizer=tokenizer, correctness_evaluator=correctness_evaluator,
+                                    previous_context=previous_context, user_prompt=user_prompt, seed=seed, return_method_details=return_method_details,
+                                    wandb_run=wandb_run, wandb_push_method_details=wandb_push_method_details, **kwargs)
+
+     for i, truth_method in enumerate(truth_methods):
+         truth_values = output_dict[f'truth_method_{i}']['truth_values']
+         truth_values = np.array(truth_values)
+         truth_values[np.isnan(truth_values)] = 0
+         correctness = output_dict['generation_correctness']
+         #if generation_correctness is -1, the model didn't attempt an answer; those samples should be removed from the evaluation
+         truth_method.normalizer.calibrate(generation_performance_scores=correctness, truth_values=truth_values)
+     return output_dict
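Usage note: a hedged sketch of calibrating a truth method's normalizer on a built-in dataset (the truth-method class name is an assumption; dataset names come from AVAILABLE_DATASETS):

    import TruthTorchLM as ttlm
    from TruthTorchLM.calibration import calibrate_truth_method

    tm = ttlm.truth_methods.Confidence()  # assumed class name
    output = calibrate_truth_method(
        dataset="trivia_qa",   # any name from AVAILABLE_DATASETS, or a list
        model="gpt-4o-mini",
        truth_methods=[tm],
        size_of_data=0.1,      # calibrate on 10% of the split
        split="train",
    )
    # tm.normalizer is now fit to map raw truth values onto correctness.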
TruthTorchLM/evaluators/__init__.py
@@ -0,0 +1,8 @@
+ from .correctness_evaluator import CorrectnessEvaluator
+ from .rouge import ROUGE
+ from .bleu import BLEU
+ from .model_judge import ModelJudge
+ from .eval_truth_method import evaluate_truth_method, get_metric_scores
+
+ __all__ = ['CorrectnessEvaluator', 'ROUGE', 'BLEU', 'evaluate_truth_method', 'ModelJudge', 'get_metric_scores']
TruthTorchLM/evaluators/bleu.py
@@ -0,0 +1,20 @@
+ from .correctness_evaluator import CorrectnessEvaluator
+ import evaluate
+
+
+ class BLEU(CorrectnessEvaluator):
+     def __init__(self, threshold: float = 0.5):
+         super().__init__()
+         self.threshold = threshold
+         self.bleu = evaluate.load('bleu')
+
+     def __call__(self, question_text: str, generated_text: str, ground_truths: list[str], seed: int = None) -> int:
+         #return 1 as soon as any reference exceeds the BLEU threshold
+         for i in range(len(ground_truths)):
+             bleu_results = self.bleu.compute(predictions=[generated_text], references=[ground_truths[i]])
+             if bleu_results['bleu'] > self.threshold:
+                 return 1
+         return 0
+
+     def __str__(self):
+         return f"BLEU with threshold {self.threshold}"
TruthTorchLM/evaluators/correctness_evaluator.py
@@ -0,0 +1,14 @@
+ from abc import ABC, abstractmethod
+
+ class CorrectnessEvaluator(ABC):
+
+     def __init__(self):
+         pass
+
+     @abstractmethod
+     def __call__(self, question_text: str, generated_text: str, ground_truth_text: list[str], seed: int = None) -> int:
+         raise NotImplementedError("Subclasses must implement this method")
+
+     @abstractmethod
+     def __str__(self):
+         raise NotImplementedError("Subclasses must implement this method")
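Usage note: any custom evaluator plugs in by subclassing this interface. A minimal exact-match sketch (illustrative only, not part of the package):

    class ExactMatch(CorrectnessEvaluator):
        """Returns 1 if any ground truth appears verbatim in the generation."""

        def __call__(self, question_text: str, generated_text: str,
                     ground_truth_text: list[str], seed: int = None) -> int:
            return int(any(gt.lower() in generated_text.lower()
                           for gt in ground_truth_text))

        def __str__(self):
            return "ExactMatch"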
TruthTorchLM/evaluators/eval_truth_method.py
@@ -0,0 +1,59 @@
+ from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
+ from typing import Union
+ from TruthTorchLM.truth_methods import TruthMethod
+ from .correctness_evaluator import CorrectnessEvaluator
+ from .rouge import ROUGE
+ from TruthTorchLM.availability import AVAILABLE_EVALUATION_METRICS
+ from TruthTorchLM.templates import DEFAULT_SYSTEM_BENCHMARK_PROMPT, DEFAULT_USER_PROMPT
+ from TruthTorchLM.utils.dataset_utils import get_dataset
+ from TruthTorchLM.utils.eval_utils import metric_score, run_over_dataset
+ import wandb
+
+
+ def evaluate_truth_method(dataset: Union[str, list], model: Union[str, PreTrainedModel], truth_methods: list[TruthMethod], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None, eval_metrics: list[str] = ['auroc'],
+                           correctness_evaluator: CorrectnessEvaluator = ROUGE(0.7), size_of_data=1.0, previous_context: list = [{'role': 'system', 'content': DEFAULT_SYSTEM_BENCHMARK_PROMPT}],
+                           user_prompt: str = DEFAULT_USER_PROMPT, seed: int = 0, return_method_details: bool = False, wandb_run=None, wandb_push_method_details: bool = False,
+                           batch_generation=True, add_generation_prompt=True, continue_final_message=False, split='test', **kwargs):
+
+     dataset = get_dataset(dataset, size_of_data=size_of_data, seed=seed, split=split)
+
+     for eval_metric in eval_metrics:
+         if eval_metric not in AVAILABLE_EVALUATION_METRICS:
+             raise ValueError(f"Evaluation metric {eval_metric} is not available. Available evaluation metrics are: {AVAILABLE_EVALUATION_METRICS}")
+
+     output_dict = run_over_dataset(dataset, model, truth_methods, tokenizer=tokenizer, correctness_evaluator=correctness_evaluator,
+                                    previous_context=previous_context, user_prompt=user_prompt, seed=seed, return_method_details=return_method_details,
+                                    wandb_run=wandb_run, wandb_push_method_details=wandb_push_method_details,
+                                    batch_generation=batch_generation, add_generation_prompt=add_generation_prompt, continue_final_message=continue_final_message, **kwargs)
+
+     eval_list = get_metric_scores(output_dict=output_dict, eval_metrics=eval_metrics, seed=seed)
+
+     if wandb_run:
+         wandb_run.log({'model_accuracy': sum(output_dict['generation_correctness']) / len(output_dict['generation_correctness'])})
+
+         eval_dict = eval_list[0]
+         for key, _ in eval_dict.items():
+             methods = []
+             scores = []
+             for i, cur_eval_dict in enumerate(eval_list):
+                 score = cur_eval_dict[key]
+                 scores.append(score)
+                 methods.append(str(truth_methods[i].__class__.__name__))
+                 wandb_run.log({f'{key}_of_method_{i}_{str(truth_methods[i].__class__.__name__)}': score})
+
+             data = [[method, score] for (method, score) in zip(methods, scores)]
+             table = wandb.Table(data=data, columns=["methods", "scores"])
+             wandb.log({f"{key}": wandb.plot.bar(table, "methods", "scores",
+                                                 title=f"{key} Scores of Truth Methods")})
+
+     return {'eval_list': eval_list, 'output_dict': output_dict}
+
+
+ def get_metric_scores(output_dict: dict, eval_metrics: list[str], seed: int = 0):
+     truth_methods = output_dict['truth_methods']
+     eval_list = []
+     for i in range(len(truth_methods)):
+         eval_dict = metric_score(eval_metrics, output_dict['generation_correctness'], output_dict[f'truth_method_{i}']['truth_values'], output_dict[f'truth_method_{i}']['normalized_truth_values'], seed=seed)
+         eval_list.append(eval_dict)
+     return eval_list
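Usage note: a sketch of benchmarking truth methods on a built-in dataset (the truth-method class name is an assumption):

    import TruthTorchLM as ttlm
    from TruthTorchLM.evaluators import evaluate_truth_method

    results = evaluate_truth_method(
        dataset="trivia_qa",
        model="gpt-4o-mini",
        truth_methods=[ttlm.truth_methods.Confidence()],  # assumed class name
        eval_metrics=["auroc", "prr"],
        size_of_data=0.05,
    )
    print(results["eval_list"])  # one {metric: score} dict per truth method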
TruthTorchLM/evaluators/model_judge.py
@@ -0,0 +1,61 @@
+ from .correctness_evaluator import CorrectnessEvaluator
+ import evaluate
+ from typing import Union
+ from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
+ from litellm import completion
+ import random
+ import torch
+ from TruthTorchLM.templates import DEFAULT_JUDGE_PROMPT, DEFAULT_JUDGE_SYSTEM_PROMPT
+
+
+ class ModelJudge(CorrectnessEvaluator):
+     def __init__(self, model: Union[PreTrainedModel, str], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None, prompt: str = DEFAULT_JUDGE_PROMPT, system_prompt: str = DEFAULT_JUDGE_SYSTEM_PROMPT, num_retries: int = 1) -> None:
+         super().__init__()
+         self.model = model
+         self.tokenizer = tokenizer
+         self.prompt = prompt
+         self.system_prompt = system_prompt
+         self.num_retries = num_retries
+
+     def __call__(self, question_text: str, generated_text: str, ground_truths: list[str], seed: int = None) -> int:
+         if seed is None:
+             seed = random.randint(0, 1000000)
+
+         chat = [{"role": "system", "content": self.system_prompt},
+                 {"role": "user", "content": self.prompt.format(question=question_text, ground_truths=', '.join(ground_truths), answer=generated_text)}]
+         if isinstance(self.model, str):
+             response = completion(
+                 model=self.model,
+                 messages=chat,
+                 seed=seed,
+                 num_retries=self.num_retries
+             )
+             judge_text = response.choices[0].message['content']
+         else:
+             torch.manual_seed(seed)
+             random.seed(seed)
+             text = self.tokenizer.apply_chat_template(chat, tokenize=False)
+             input_ids = self.tokenizer.encode(text, return_tensors="pt").to(self.model.device)
+             model_output = self.model.generate(input_ids)
+             tokens = model_output[0][len(input_ids[0]):]
+             judge_text = self.tokenizer.decode(tokens, skip_special_tokens=False)
+
+         #'incorrect' must be checked first since it contains 'correct' as a substring
+         if 'incorrect' in judge_text.lower():
+             return 0
+         elif 'correct' in judge_text.lower():
+             return 1
+         elif "not_attempted" in judge_text.lower():
+             return -1
+         else:
+             #output warning
+             print("The judge model's output is not in the expected format; returning not attempted (-1).")
+             return -1
+
+     def __str__(self):
+         return f"ModelJudge with model {self.model}"
TruthTorchLM/evaluators/rouge.py
@@ -0,0 +1,19 @@
+ from .correctness_evaluator import CorrectnessEvaluator
+ import evaluate
+
+ class ROUGE(CorrectnessEvaluator):
+     def __init__(self, threshold: float = 0.5, rouge_type: str = 'rougeL'):
+         super().__init__()
+         self.threshold = threshold
+         self.rouge = evaluate.load('rouge')
+         self.rouge_type = rouge_type
+
+     def __call__(self, question_text: str, generated_text: str, ground_truths: list[str], seed: int = None) -> int:
+         #return 1 as soon as any reference exceeds the ROUGE threshold
+         for i in range(len(ground_truths)):
+             rouge_results = self.rouge.compute(predictions=[generated_text], references=[ground_truths[i]])
+             if rouge_results[self.rouge_type] > self.threshold:
+                 return 1
+         return 0
+
+     def __str__(self):
+         return f"ROUGE with threshold {self.threshold} and type {self.rouge_type}"
TruthTorchLM/generation.py
@@ -0,0 +1,389 @@
+ import copy
+ import torch
+ import random
+ from typing import Union
+ from litellm import completion
+
+ from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
+
+ #from .truth_methods.truth_method import TruthMethod
+ from TruthTorchLM.availability import AVAILABLE_API_MODELS, PROB_AVAILABLE_API_MODELS
+ from TruthTorchLM.utils.common_utils import generate, fix_tokenizer_chat
+
+ import time
+
+
+ def generate_with_truth_value(model: Union[PreTrainedModel, str], messages: list, question_context: str = None, truth_methods: list = [], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
+                               generation_seed=None, batch_generation=True, add_generation_prompt=True, continue_final_message=False, **kwargs) -> dict:
+     if isinstance(model, str):
+         return generate_with_truth_value_api(model=model, messages=messages, question_context=question_context, truth_methods=truth_methods, generation_seed=generation_seed, **kwargs)
+     else:
+         return generate_with_truth_value_hf_local(model=model, messages=messages, question_context=question_context, truth_methods=truth_methods,
+                                                   tokenizer=tokenizer, generation_seed=generation_seed, batch_generation=batch_generation, add_generation_prompt=add_generation_prompt, continue_final_message=continue_final_message, **kwargs)
+
+
+ #TODO: remove number of generations from kwargs if it exists
+ def generate_with_truth_value_hf_local(model: PreTrainedModel, messages: list, question_context: str = None, truth_methods: list = [],
+                                        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None, generation_seed=None, batch_generation=True, add_generation_prompt=True, continue_final_message=False, **kwargs) -> dict:
+
+     tokenizer, messages = fix_tokenizer_chat(tokenizer, messages)
+     text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=add_generation_prompt, continue_final_message=continue_final_message)
+     if question_context is None:
+         question_context = ''
+         #fall back to the last user message if one exists
+         for message in messages[::-1]:
+             if message['role'] == 'user':
+                 question_context = message['content']
+                 break
+
+     generated_output = generate(text, model, tokenizer, **kwargs)
+     generated_text_return = generated_output['generated_text_skip_specials']
+     generated_text = generated_output['generated_text']
+     tokens = generated_output['tokens']
+     model_output = generated_output['all_ids']
+
+     #Get sampled generations to be used in truth methods
+     number_of_generations, return_text, return_logits, return_logprobs, return_attentions, return_activations = get_sampling_properties(truth_methods)
+
+     sampled_gen_dict = sample_generations_hf_local(model, text, tokenizer, generation_seed, number_of_generations=number_of_generations,
+                                                    return_text=return_text, return_logits=return_logits, return_logprobs=return_logprobs, return_attentions=return_attentions, return_activations=return_activations, batch_generation=batch_generation, **kwargs)
+
+     # Get scores from all truth methods
+     normalized_truth_values = []
+     unnormalized_truth_values = []
+     method_spec_outputs = []
+
+     for truth_method in truth_methods:
+         truth_values = truth_method(model=model, input_text=text, generated_text=generated_text, question_context=question_context, all_ids=model_output, tokenizer=tokenizer, generation_seed=generation_seed, sampled_generations_dict=sampled_gen_dict, **kwargs)
+         normalized_truth_values.append(truth_values['normalized_truth_value'])
+         unnormalized_truth_values.append(truth_values['truth_value'])
+         method_spec_outputs.append(truth_values)
+
+     # Create the truth dict
+     truth_dict = {'generated_text': generated_text_return, 'normalized_truth_values': normalized_truth_values, 'unnormalized_truth_values': unnormalized_truth_values, 'method_specific_outputs': method_spec_outputs, 'all_ids': model_output.cpu(), 'generated_tokens': tokens}
+
+     return truth_dict
+
+
+ #for api-based models, we should write a wrapper function to handle exceptions during the api call
+ def generate_with_truth_value_api(model: str, messages: list, question_context: str = None, truth_methods: list = [], generation_seed=None, **kwargs) -> dict:
+     # Check if the model is an API model
+     if generation_seed is not None:
+         random.seed(generation_seed)
+
+     if isinstance(model, str) and model not in AVAILABLE_API_MODELS:
+         raise ValueError(f"model {model} is not supported.")
+
+     requires_logprobs = False
+     for truth_method in truth_methods:
+         if truth_method.REQUIRES_LOGPROBS:
+             requires_logprobs = True
+
+     if requires_logprobs and model not in PROB_AVAILABLE_API_MODELS:
+         raise ValueError(f"model {model} is not supported for truth methods that require log probabilities.")
+
+     if question_context is None:
+         question_context = ''
+         #fall back to the last user message if one exists
+         for message in messages[::-1]:
+             if message['role'] == 'user':
+                 question_context = message['content']
+                 break
+
+     # Generate the main output
+     seed = kwargs.pop('seed', None)
+     if seed is None:
+         seed = random.randint(0, 1000000)
+     kwargs['seed'] = seed #a random seed is generated if seed is not specified
+
+     response = completion(
+         model=model,
+         messages=messages,
+         logprobs=requires_logprobs,
+         **kwargs
+     )
+     generated_text = response.choices[0].message['content']
+
+     logprobs = [token['logprob'] for token in response.choices[0].logprobs['content']] if requires_logprobs else None
+     generated_tokens = [token['token'] for token in response.choices[0].logprobs['content']] if requires_logprobs else None
+
+     #Get sampled generations to be used in truth methods
+     number_of_generations, return_text, return_logits, return_logprobs, return_attentions, return_activations = get_sampling_properties(truth_methods)
+     sampled_gen_dict = sample_generations_api(model, messages, generation_seed, number_of_generations=number_of_generations,
+                                               return_text=return_text, return_logits=return_logits, return_logprobs=return_logprobs, return_attentions=return_attentions, return_activations=return_activations, **kwargs)
+
+     # Get scores from all truth methods
+     normalized_truth_values = []
+     unnormalized_truth_values = []
+     method_spec_outputs = []
+
+     for truth_method in truth_methods:
+         truth_values = truth_method(model=model, messages=messages, generated_text=generated_text, question_context=question_context, generation_seed=generation_seed, sampled_generations_dict=sampled_gen_dict, logprobs=logprobs, generated_tokens=generated_tokens, **kwargs)
+         normalized_truth_values.append(truth_values['normalized_truth_value'])
+         unnormalized_truth_values.append(truth_values['truth_value'])
+         method_spec_outputs.append(truth_values)
+
+     # Create the truth dict
+     truth_dict = {'generated_text': generated_text, 'normalized_truth_values': normalized_truth_values, 'unnormalized_truth_values': unnormalized_truth_values, 'method_specific_outputs': method_spec_outputs}
+
+     return truth_dict
+
+
+ def get_sampling_properties(truth_methods: list):
+     number_of_generations = 0
+     return_text = False
+     return_logits = False
+     return_logprobs = False
+     return_attentions = False
+     return_activations = False
+     #search over all truth methods for the largest number of generations and the union of required outputs
+     for truth_method in truth_methods:
+         if hasattr(truth_method, 'number_of_generations') and truth_method.number_of_generations > number_of_generations:
+             number_of_generations = truth_method.number_of_generations
+         if truth_method.REQUIRES_SAMPLED_TEXT:
+             return_text = True
+         if truth_method.REQUIRES_SAMPLED_LOGITS:
+             return_logits = True
+         if truth_method.REQUIRES_SAMPLED_LOGPROBS:
+             return_logprobs = True
+         if truth_method.REQUIRES_SAMPLED_ATTENTIONS:
+             return_attentions = True
+         if truth_method.REQUIRES_SAMPLED_ACTIVATIONS:
+             return_activations = True
+     return number_of_generations, return_text, return_logits, return_logprobs, return_attentions, return_activations
+
+
+ def sample_generations_hf_local(model: PreTrainedModel, input_text: str, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], generation_seed: int = None,
+                                 number_of_generations: int = 0, return_text: bool = False, return_logits: bool = False, return_logprobs: bool = False, return_attentions: bool = False, return_activations: bool = False, batch_generation=False, **kwargs):
+
+     if number_of_generations == 0 or (not return_text and not return_logprobs and not return_activations and not return_attentions and not return_logits):
+         return None
+
+     if generation_seed is not None:
+         torch.manual_seed(generation_seed)
+         random.seed(generation_seed)
+
+     if batch_generation:
+         return sample_generations_batch_hf_local(model=model, input_text=input_text, tokenizer=tokenizer, number_of_generations=number_of_generations,
+                                                  return_text=return_text, return_logits=return_logits, return_logprobs=return_logprobs,
+                                                  return_attentions=return_attentions, return_activations=return_activations, **kwargs)
+     else:
+         return sample_generations_sequential_hf_local(model=model, input_text=input_text, tokenizer=tokenizer, number_of_generations=number_of_generations,
+                                                       return_text=return_text, return_logits=return_logits, return_logprobs=return_logprobs,
+                                                       return_attentions=return_attentions, return_activations=return_activations, **kwargs)
+
+
+ def sample_generations_batch_hf_local(model: PreTrainedModel, input_text: str, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+                                       number_of_generations: int = 0, return_text: bool = False, return_logits: bool = False, return_logprobs: bool = False, return_attentions: bool = False, return_activations: bool = False, return_model_output: bool = True, **kwargs):
+
+     if number_of_generations == 0 or (not return_text and not return_logprobs and not return_activations and not return_attentions and not return_logits):
+         return None
+
+     kwargs = copy.deepcopy(kwargs)
+     kwargs.pop('do_sample', None)
+     kwargs.pop('num_return_sequences', None)
+     kwargs.pop('return_dict_in_generate', None)
+     kwargs.pop('output_attentions', None)
+     kwargs.pop('output_hidden_states', None)
+     kwargs.pop('output_logits', None)
+     inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
+     input_ids = inputs['input_ids']
+
+     eos_token_id = kwargs.pop("eos_token_id", None)
+     if eos_token_id is None:
+         eos_token_id = model.config.eos_token_id
+
+     pad_token_id = kwargs.pop("pad_token_id", None)
+     if pad_token_id is None:
+         if isinstance(eos_token_id, list):
+             pad_token_id = eos_token_id[0]
+         else:
+             pad_token_id = eos_token_id
+
+     generated_texts = []
+     logits_list = []
+     logprobs = []
+     attentions_list = []
+     activations_list = []
+     tokens = []
+
+     with torch.no_grad():
+         model_output = model.generate(**inputs, num_return_sequences=number_of_generations, do_sample=True, return_dict_in_generate=True,
+                                       output_attentions=return_attentions, output_hidden_states=return_activations, output_logits=(return_logits or return_logprobs),
+                                       eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs)
+
+     model_output.past_key_values = None
+     model_output.sequences = model_output.sequences.cpu()
+     #find the position of the first eos token in each generation
+     if isinstance(eos_token_id, list):
+         temp = torch.stack([torch.argmax((model_output.sequences[:, len(input_ids[0]):] == eos).to(dtype=torch.int), dim=-1) for eos in eos_token_id]).T
+         indices = [torch.min(temp[i][temp[i] > 0]).item() for i in range(len(temp))]
+     else:
+         indices = torch.argmax((model_output.sequences[:, len(input_ids[0]):] == eos_token_id).to(dtype=torch.int), dim=-1)
+         indices[indices == 0] = model_output.sequences.shape[1] - len(input_ids[0]) - 1
+     if return_text:
+         tokens = [seq[len(input_ids[0]):indices[i] + len(input_ids[0]) + 1].tolist() for i, seq in enumerate(model_output.sequences)]
+         generated_texts = tokenizer.batch_decode(tokens, skip_special_tokens=True)
+     if return_logprobs or return_logits:
+         logits_list = torch.stack(model_output.logits).cpu().permute(1, 0, 2)
+         model_output.logits = None
+         if return_logprobs:
+             logprobs = torch.log_softmax(logits_list, dim=-1) #logprobs over the vocabulary at each position
+             logprobs = torch.gather(logprobs, dim=-1, index=model_output.sequences[:, len(input_ids[0]):].unsqueeze(-1)) #logprobs of the generated tokens
+             logprobs = logprobs.squeeze(-1).tolist() #convert to list
+             logprobs = [logprobs[i][:indices[i] + 1] for i in range(len(logprobs))]
+         if return_logits:
+             logits_list = [logits_list[i][:indices[i] + 1] for i in range(len(logits_list))]
+         else:
+             logits_list = []
+     if return_activations:
+         activations_list = [] #shape = (num gen, num token, num layer, hidden_state_shape)
+         for i in range(number_of_generations): #generation id
+             acts = []
+             for j in range(indices[i] + 1): #token id
+                 act = []
+                 for k in range(len(model_output.hidden_states[0])): #layer id
+                     act.append(model_output.hidden_states[j][k][i].cpu())
+                 acts.append(act)
+             activations_list.append(acts)
+         model_output.hidden_states = None
+     if return_attentions:
+         attentions_list = [] #shape = (num gen, num token, num layer, attention_shape)
+         for i in range(number_of_generations): #generation id
+             atts = []
+             for j in range(indices[i] + 1): #token id
+                 att = []
+                 for k in range(len(model_output.attentions[0])): #layer id
+                     att.append(model_output.attentions[j][k][i].cpu())
+                 atts.append(att)
+             attentions_list.append(atts)
+         model_output.attentions = None
+
+     if not return_model_output:
+         model_output.sequences = None
+
+     return {"generated_texts": generated_texts, "logprobs": logprobs, "activations": activations_list, "logits": logits_list, "attentions": attentions_list, "model_outputs": model_output.sequences, "tokens": tokens}
+
+
+ def sample_generations_sequential_hf_local(model: PreTrainedModel, input_text: str, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+                                            number_of_generations: int = 0, do_sample: bool = True, return_text: bool = False, return_logits: bool = False, return_logprobs: bool = False, return_attentions: bool = False, return_activations: bool = False, return_model_output: bool = True, **kwargs):
+
+     kwargs = copy.deepcopy(kwargs)
+     kwargs.pop('do_sample', None)
+     kwargs.pop('num_return_sequences', None)
+     kwargs.pop('return_dict_in_generate', None)
+     kwargs.pop('output_attentions', None)
+     kwargs.pop('output_hidden_states', None)
+     kwargs.pop('output_logits', None)
+     inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
+     input_ids = inputs['input_ids']
+
+     eos_token_id = kwargs.pop("eos_token_id", None)
+     if eos_token_id is None:
+         eos_token_id = model.config.eos_token_id
+
+     generated_texts = []
+     logits_list = []
+     logprobs_list = []
+     attentions_list = []
+     activations_list = []
+     model_outputs = []
+     token_lists = []
+     for _ in range(number_of_generations):
+         with torch.no_grad():
+             model_output = model.generate(**inputs, num_return_sequences=1, do_sample=do_sample, return_dict_in_generate=True,
+                                           output_attentions=return_attentions, output_hidden_states=return_activations,
+                                           output_logits=(return_logits or return_logprobs), eos_token_id=eos_token_id, **kwargs)
+         model_output.past_key_values = None
+         model_output.sequences = model_output.sequences.cpu()
+         if return_model_output:
+             model_outputs.append(model_output.sequences)
+         if return_text:
+             tokens = model_output.sequences[0][len(input_ids[0]):]
+             generated_text = tokenizer.decode(tokens, skip_special_tokens=True)
+             generated_texts.append(generated_text)
+             token_lists.append(tokens.tolist())
+         if return_logprobs or return_logits:
+             logits = torch.cat(model_output.logits).cpu()
+             model_output.logits = None
+             if return_logprobs:
+                 logprobs = torch.log_softmax(logits, dim=-1) #logprobs over the vocabulary at each position
+                 logprobs = torch.gather(logprobs, dim=1, index=model_output.sequences[0][len(input_ids[0]):].view(-1, 1)) #logprobs of the generated tokens
+                 logprobs = logprobs.view(-1).tolist() #convert to list
+                 logprobs_list.append(logprobs)
+             if return_logits:
+                 logits_list.append(logits)
+         if return_activations:
+             acts = []
+             for i in range(len(model_output.hidden_states)): #token id
+                 act = []
+                 for j in range(len(model_output.hidden_states[i])): #layer id
+                     act.append(model_output.hidden_states[i][j][0].cpu())
+                 acts.append(act)
+             activations_list.append(acts)
+             model_output.hidden_states = None
+         if return_attentions:
+             atts = []
+             for i in range(len(model_output.attentions)): #token id
+                 att = []
+                 for j in range(len(model_output.attentions[i])): #layer id
+                     att.append(model_output.attentions[i][j][0].cpu())
+                 atts.append(att)
+             attentions_list.append(atts)
+             model_output.attentions = None
+
+     return {"generated_texts": generated_texts, "logprobs": logprobs_list, "activations": activations_list, "logits": logits_list, "attentions": attentions_list, "model_outputs": model_outputs, "tokens": token_lists}
+
+
+ def sample_generations_api(model: str, messages: list, generation_seed: int = None,
+                            number_of_generations: int = 0, return_text: bool = False, return_logits: bool = False, return_logprobs: bool = False, return_attentions: bool = False, return_activations: bool = False, **kwargs):
+
+     if number_of_generations == 0 or (not return_text and not return_logprobs):
+         return None
+
+     if generation_seed is not None:
+         random.seed(generation_seed)
+
+     kwargs = copy.deepcopy(kwargs)
+
+     generated_texts = []
+     logprobs_list = []
+     token_lists = []
+     for i in range(number_of_generations):
+         kwargs.pop('logprobs', None)
+         kwargs.pop('seed', None)
+         kwargs['seed'] = random.randint(0, 1000000) #a fresh seed for every sampled generation
+
+         response = completion(
+             model=model,
+             messages=messages,
+             logprobs=return_logprobs,
+             **kwargs
+         )
+
+         if return_text:
+             generated_texts.append(response.choices[0].message['content'])
+         if return_logprobs:
+             logprobs_list.append([token['logprob'] for token in response.choices[0].logprobs['content']])
+             token_lists.append([token['token'] for token in response.choices[0].logprobs['content']])
+
+     return {"generated_texts": generated_texts, "logprobs": logprobs_list, "tokens": token_lists}
TruthTorchLM/long_form_generation/__init__.py
@@ -0,0 +1,5 @@
+ from .generation import long_form_generation_with_truth_value
+ from .decomposition_methods import *
+ from .statement_check_methods import *
+ from .evaluators import *
+ from .utils import *
TruthTorchLM/long_form_generation/decomposition_methods/__init__.py
@@ -0,0 +1,8 @@
+ from .decomposition_method import FactualDecompositionMethod
+ from .unstructured_decomposition_api import UnstructuredDecompositionAPI
+ from .unstructured_decomposition_local import UnstructuredDecompositionLocal
+ from .structured_decomposition_api import StructuredDecompositionAPI
+ from .structured_decomposition_local import StructuredDecompositionLocal
+
+
+ __all__ = ['UnstructuredDecompositionAPI', 'UnstructuredDecompositionLocal', 'StructuredDecompositionAPI', 'StructuredDecompositionLocal']