wisent 0.5.12__py3-none-any.whl → 0.5.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of wisent might be problematic.
- wisent/__init__.py +1 -1
- wisent/core/activations/__init__.py +26 -0
- wisent/core/activations/activations.py +96 -0
- wisent/core/activations/activations_collector.py +71 -20
- wisent/core/activations/prompt_construction_strategy.py +47 -0
- wisent/core/agent/budget.py +2 -2
- wisent/core/agent/device_benchmarks.py +1 -1
- wisent/core/agent/diagnose/classifier_marketplace.py +8 -8
- wisent/core/agent/diagnose/response_diagnostics.py +4 -4
- wisent/core/agent/diagnose/synthetic_classifier_option.py +1 -1
- wisent/core/agent/diagnose/tasks/task_manager.py +3 -3
- wisent/core/agent/diagnose.py +2 -1
- wisent/core/autonomous_agent.py +10 -2
- wisent/core/benchmark_extractors.py +293 -0
- wisent/core/bigcode_integration.py +20 -7
- wisent/core/branding.py +108 -0
- wisent/core/cli/__init__.py +15 -0
- wisent/core/cli/create_steering_vector.py +138 -0
- wisent/core/cli/evaluate_responses.py +715 -0
- wisent/core/cli/generate_pairs.py +128 -0
- wisent/core/cli/generate_pairs_from_task.py +119 -0
- wisent/core/cli/generate_responses.py +129 -0
- wisent/core/cli/generate_vector_from_synthetic.py +149 -0
- wisent/core/cli/generate_vector_from_task.py +147 -0
- wisent/core/cli/get_activations.py +191 -0
- wisent/core/cli/optimize_classification.py +339 -0
- wisent/core/cli/optimize_steering.py +364 -0
- wisent/core/cli/tasks.py +182 -0
- wisent/core/cli_logger.py +22 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +27 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +49 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +118 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +146 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +129 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gsm8k.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/livecodebench.py +367 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +121 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +121 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +110 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/squad2.py +124 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sst2.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +127 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_gen.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +117 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +117 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +127 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +1 -1
- wisent/core/data_loaders/__init__.py +235 -0
- wisent/core/data_loaders/loaders/lm_loader.py +2 -2
- wisent/core/data_loaders/loaders/task_interface_loader.py +300 -0
- wisent/{cli/data_loaders/data_loader_rotator.py → core/data_loaders/rotator.py} +1 -1
- wisent/core/download_full_benchmarks.py +79 -2
- wisent/core/evaluators/benchmark_specific/__init__.py +26 -0
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/evaluator.py +17 -17
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/cpp_sanitizer.py +2 -2
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/java_sanitizer.py +2 -2
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/python_sanitizer.py +2 -2
- wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/__init__.py +3 -0
- wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/provider.py +305 -0
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/core/runtime.py +36 -4
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/entrypoint.py +2 -4
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/recipes.py +1 -1
- wisent/core/evaluators/benchmark_specific/coding/solution_generator.py +258 -0
- wisent/core/evaluators/benchmark_specific/exact_match_evaluator.py +79 -0
- wisent/core/evaluators/benchmark_specific/f1_evaluator.py +101 -0
- wisent/core/evaluators/benchmark_specific/generation_evaluator.py +197 -0
- wisent/core/{log_likelihoods_evaluator.py → evaluators/benchmark_specific/log_likelihoods_evaluator.py} +10 -2
- wisent/core/evaluators/benchmark_specific/perplexity_evaluator.py +140 -0
- wisent/core/evaluators/benchmark_specific/personalization_evaluator.py +250 -0
- wisent/{cli/evaluators/evaluator_rotator.py → core/evaluators/rotator.py} +4 -4
- wisent/core/lm_eval_harness_ground_truth.py +3 -2
- wisent/core/main.py +57 -0
- wisent/core/model_persistence.py +2 -2
- wisent/core/models/wisent_model.py +6 -6
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
- wisent/core/optuna/steering/steering_optimization.py +1 -1
- wisent/core/parser_arguments/__init__.py +10 -0
- wisent/core/parser_arguments/agent_parser.py +110 -0
- wisent/core/parser_arguments/configure_model_parser.py +7 -0
- wisent/core/parser_arguments/create_steering_vector_parser.py +59 -0
- wisent/core/parser_arguments/evaluate_parser.py +40 -0
- wisent/core/parser_arguments/evaluate_responses_parser.py +10 -0
- wisent/core/parser_arguments/full_optimize_parser.py +115 -0
- wisent/core/parser_arguments/generate_pairs_from_task_parser.py +33 -0
- wisent/core/parser_arguments/generate_pairs_parser.py +29 -0
- wisent/core/parser_arguments/generate_responses_parser.py +15 -0
- wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +127 -0
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +127 -0
- wisent/core/parser_arguments/generate_vector_parser.py +90 -0
- wisent/core/parser_arguments/get_activations_parser.py +90 -0
- wisent/core/parser_arguments/main_parser.py +152 -0
- wisent/core/parser_arguments/model_config_parser.py +59 -0
- wisent/core/parser_arguments/monitor_parser.py +17 -0
- wisent/core/parser_arguments/multi_steer_parser.py +47 -0
- wisent/core/parser_arguments/optimize_classification_parser.py +67 -0
- wisent/core/parser_arguments/optimize_sample_size_parser.py +58 -0
- wisent/core/parser_arguments/optimize_steering_parser.py +147 -0
- wisent/core/parser_arguments/synthetic_parser.py +93 -0
- wisent/core/parser_arguments/tasks_parser.py +584 -0
- wisent/core/parser_arguments/test_nonsense_parser.py +26 -0
- wisent/core/parser_arguments/utils.py +111 -0
- wisent/core/prompts/core/prompt_formater.py +3 -3
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +2 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +2 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +2 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +2 -0
- wisent/{cli/steering_methods/steering_rotator.py → core/steering_methods/rotator.py} +4 -4
- wisent/core/steering_optimizer.py +45 -21
- wisent/{synthetic → core/synthetic}/cleaners/deduper_cleaner.py +3 -3
- wisent/{synthetic → core/synthetic}/cleaners/methods/base_dedupers.py +2 -2
- wisent/{synthetic → core/synthetic}/cleaners/methods/base_refusalers.py +1 -1
- wisent/{synthetic → core/synthetic}/cleaners/pairs_cleaner.py +5 -5
- wisent/{synthetic → core/synthetic}/cleaners/refusaler_cleaner.py +4 -4
- wisent/{synthetic → core/synthetic}/db_instructions/mini_dp.py +1 -1
- wisent/{synthetic → core/synthetic}/generators/diversities/methods/fast_diversity.py +1 -1
- wisent/{synthetic → core/synthetic}/generators/pairs_generator.py +38 -12
- wisent/core/tasks/livecodebench_task.py +4 -103
- wisent/core/timing_calibration.py +1 -1
- {wisent-0.5.12.dist-info → wisent-0.5.13.dist-info}/METADATA +3 -3
- wisent-0.5.13.dist-info/RECORD +294 -0
- wisent-0.5.13.dist-info/entry_points.txt +2 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +0 -53
- wisent/classifiers/core/atoms.py +0 -747
- wisent/classifiers/models/logistic.py +0 -29
- wisent/classifiers/models/mlp.py +0 -47
- wisent/cli/classifiers/classifier_rotator.py +0 -137
- wisent/cli/cli_logger.py +0 -142
- wisent/cli/wisent_cli/commands/help_cmd.py +0 -52
- wisent/cli/wisent_cli/commands/listing.py +0 -154
- wisent/cli/wisent_cli/commands/train_cmd.py +0 -322
- wisent/cli/wisent_cli/main.py +0 -93
- wisent/cli/wisent_cli/shell.py +0 -80
- wisent/cli/wisent_cli/ui.py +0 -69
- wisent/cli/wisent_cli/util/aggregations.py +0 -43
- wisent/cli/wisent_cli/util/parsing.py +0 -126
- wisent/cli/wisent_cli/version.py +0 -4
- wisent/opti/methods/__init__.py +0 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent-0.5.12.dist-info/RECORD +0 -220
- /wisent/{benchmarks → core/evaluators/benchmark_specific/coding}/__init__.py +0 -0
- /wisent/{benchmarks/coding → core/evaluators/benchmark_specific/coding/metrics}/__init__.py +0 -0
- /wisent/{benchmarks/coding/metrics → core/evaluators/benchmark_specific/coding/metrics/core}/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/core/atoms.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/passk.py +0 -0
- /wisent/{benchmarks/coding/metrics/core → core/evaluators/benchmark_specific/coding/output_sanitizer}/__init__.py +0 -0
- /wisent/{benchmarks/coding/output_sanitizer → core/evaluators/benchmark_specific/coding/output_sanitizer/core}/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/core/atoms.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/utils.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/providers/__init__.py +0 -0
- /wisent/{benchmarks/coding/output_sanitizer → core/evaluators/benchmark_specific/coding/providers}/core/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/providers/core/atoms.py +0 -0
- /wisent/{benchmarks/coding/providers/core → core/evaluators/benchmark_specific/coding/safe_docker}/__init__.py +0 -0
- /wisent/{benchmarks/coding/providers/livecodebench → core/evaluators/benchmark_specific/coding/safe_docker/core}/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/core/atoms.py +0 -0
- /wisent/{benchmarks/coding/safe_docker → core/opti}/__init__.py +0 -0
- /wisent/{benchmarks/coding/safe_docker → core/opti}/core/__init__.py +0 -0
- /wisent/{opti → core/opti}/core/atoms.py +0 -0
- /wisent/{classifiers → core/opti/methods}/__init__.py +0 -0
- /wisent/{opti → core/opti}/methods/opti_classificator.py +0 -0
- /wisent/{opti → core/opti}/methods/opti_steering.py +0 -0
- /wisent/{classifiers/core → core/synthetic}/__init__.py +0 -0
- /wisent/{classifiers/models → core/synthetic/cleaners}/__init__.py +0 -0
- /wisent/{cli → core/synthetic/cleaners/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/cleaners/core/atoms.py +0 -0
- /wisent/{cli/classifiers → core/synthetic/cleaners/methods}/__init__.py +0 -0
- /wisent/{cli/data_loaders → core/synthetic/cleaners/methods/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/cleaners/methods/core/atoms.py +0 -0
- /wisent/{cli/evaluators → core/synthetic/db_instructions}/__init__.py +0 -0
- /wisent/{cli/steering_methods → core/synthetic/db_instructions/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/db_instructions/core/atoms.py +0 -0
- /wisent/{cli/wisent_cli → core/synthetic/generators}/__init__.py +0 -0
- /wisent/{cli/wisent_cli/commands → core/synthetic/generators/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/generators/core/atoms.py +0 -0
- /wisent/{cli/wisent_cli/util → core/synthetic/generators/diversities}/__init__.py +0 -0
- /wisent/{opti → core/synthetic/generators/diversities/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/generators/diversities/core/core.py +0 -0
- /wisent/{opti/core → core/synthetic/generators/diversities/methods}/__init__.py +0 -0
- {wisent-0.5.12.dist-info → wisent-0.5.13.dist-info}/WHEEL +0 -0
- {wisent-0.5.12.dist-info → wisent-0.5.13.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.5.12.dist-info → wisent-0.5.13.dist-info}/top_level.txt +0 -0
wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py
@@ -1,8 +1,56 @@
 __all__ = [
     "EXTRACTORS",
 ]
-base_import: str = "
+base_import: str = "wisent.core.contrastive_pairs.lm_eval_pairs.lm_task_extractors."
 EXTRACTORS: dict[str, str] = {
     # key → "module_path:ClassName" (supports dotted attr path after ':')
+    "arc_challenge": f"{base_import}arc_challenge:Arc_ChallengeExtractor",
+    "arc_easy": f"{base_import}arc_easy:Arc_EasyExtractor",
+    "arithmetic": f"{base_import}arithmetic:ArithmeticExtractor",
+    "asdiv": f"{base_import}asdiv:ASDivExtractor",
+    "boolq": f"{base_import}boolq:BoolQExtractor",
+    "cb": f"{base_import}cb:CBExtractor",
+    "copa": f"{base_import}copa:COPAExtractor",
+    "coqa": f"{base_import}coqa:CoQAExtractor",
+    "drop": f"{base_import}drop:DROPExtractor",
+    "gsm8k": f"{base_import}gsm8k:GSM8KExtractor",
+    "headqa": f"{base_import}headqa:HeadQAExtractor",
+    "hellaswag": f"{base_import}hellaswag:HellaSwagExtractor",
+    "logiqa": f"{base_import}logiqa:LogiQAExtractor",
+    "logiqa2": f"{base_import}logiqa2:LogiQA2Extractor",
+    "mc_taco": f"{base_import}mc-taco:MCTACOExtractor",
+    "medqa": f"{base_import}medqa:MedQAExtractor",
+    "mrpc": f"{base_import}mrpc:MRPCExtractor",
+    "multirc": f"{base_import}multirc:MultiRCExtractor",
+    "mutual": f"{base_import}mutual:MutualExtractor",
+    "openbookqa": f"{base_import}openbookqa:OpenBookQAExtractor",
+    "pawsx": f"{base_import}pawsx:PAWSXExtractor",
+    "piqa": f"{base_import}piqa:PIQAExtractor",
+    "prost": f"{base_import}prost:PROSTExtractor",
+    "pubmedqa": f"{base_import}pubmedqa:PubMedQAExtractor",
+    "qa4mre": f"{base_import}qa4mre:QA4MREExtractor",
+    "qasper": f"{base_import}qasper:QASPERExtractor",
+    "qnli": f"{base_import}qnli:QNLIExtractor",
+    "qqp": f"{base_import}qqp:QQPExtractor",
+    "race": f"{base_import}race:RaceExtractor",
+    "record": f"{base_import}record:ReCoRDExtractor",
+    "rte": f"{base_import}rte:RTEExtractor",
+    "sciq": f"{base_import}sciq:SciQExtractor",
+    "social_iqa": f"{base_import}social_iqa:Social_IQAExtractor",
+    "squad2": f"{base_import}squad2:SQuAD2Extractor",
+    "sst2": f"{base_import}sst2:SST2Extractor",
+    "swag": f"{base_import}swag:SWAGExtractor",
+    "triviaqa": f"{base_import}triviaqa:TriviaQAExtractor",
+    "truthfulqa_gen": f"{base_import}truthfulqa_gen:TruthfulQA_GenExtractor",
+    "truthfulqa_mc1": f"{base_import}truthfulqa_mc1:TruthfulQA_MC1Extractor",
+    "truthfulqa_mc2": f"{base_import}truthfulqa_mc2:TruthfulQA_MC2Extractor",
+    "webqs": f"{base_import}webqs:WebQuestionsExtractor",
+    "wic": f"{base_import}wic:WiCExtractor",
     "winogrande": f"{base_import}winogrande:WinograndeExtractor",
+    "wnli": f"{base_import}wnli:WNLIExtractor",
+    "wsc": f"{base_import}wsc:WSCExtractor",
+    "xnli": f"{base_import}xnli:XNLIExtractor",
+    "xstorycloze": f"{base_import}xstorycloze:XStoryClozeExtractor",
+    "xwinograd": f"{base_import}xwinograd:XWinogradExtractor",
+    "livecodebench": f"{base_import}livecodebench:LiveCodeBenchExtractor",
 }
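The manifest values follow the lazy-import convention documented in the comment above: each entry is a "module_path:ClassName" string, optionally with a dotted attribute path after the colon. A minimal sketch of how such a string can be resolved with importlib (a hypothetical resolve_extractor helper, not code from the package):

import importlib

def resolve_extractor(spec: str):
    # Split "module_path:ClassName"; the part after ':' may be dotted (e.g. "outer.Inner").
    module_path, _, attr_path = spec.partition(":")
    obj = importlib.import_module(module_path)
    for attr in attr_path.split("."):
        obj = getattr(obj, attr)  # walk the dotted attribute path
    return obj

# e.g. resolve_extractor(EXTRACTORS["boolq"]) would return the BoolQExtractor class.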
wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py (new file)
@@ -0,0 +1,115 @@
+from __future__ import annotations
+
+from typing import Any, TYPE_CHECKING
+
+from wisent.core.contrastive_pairs.core.pair import ContrastivePair
+from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
+from wisent.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
+from wisent.core.cli_logger import setup_logger, bind
+
+if TYPE_CHECKING:
+    from lm_eval.api.task import ConfigurableTask
+
+
+__all__ = ["Arc_ChallengeExtractor"]
+_LOG = setup_logger(__name__)
+
+
+class Arc_ChallengeExtractor(LMEvalBenchmarkExtractor):
+    """Extractor for the Arc_Challenge benchmark."""
+
+    def extract_contrastive_pairs(
+        self,
+        lm_eval_task_data: ConfigurableTask,
+        limit: int | None = None,
+    ) -> list[ContrastivePair]:
+        """
+        Build contrastive pairs from Arc_Challenge docs.
+
+        Arc_Challenge schema:
+            - question
+            - choices: dict,
+            - choices["text"]: list with possible choices strings
+            - answerKey: str
+
+        Args:
+            lm_eval_task_data: lm-eval task instance for Arc_Challenge.
+            limit: Optional maximum number of pairs to produce.
+
+        Returns:
+            A list of ContrastivePair objects.
+        """
+        log = bind(_LOG, task=getattr(lm_eval_task_data, "NAME", "unknown"))
+
+        max_items = self._normalize_limit(limit)
+        docs = self.load_docs(lm_eval_task_data, max_items)
+
+        pairs: list[ContrastivePair] = []
+
+        log.info("Extracting contrastive pairs", extra={"doc_count": len(docs)})
+
+        for doc in docs:
+            pair = self._extract_pair_from_doc(doc)
+            if pair is not None:
+                pairs.append(pair)
+                if max_items is not None and len(pairs) >= max_items:
+                    break
+
+        if not pairs:
+            task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
+            log.warning("No valid Arc_Challenge pairs extracted", extra={"task": task_name})
+
+        return pairs
+
+    def _extract_pair_from_doc(self, doc: dict[str, Any]) -> ContrastivePair | None:
+        """
+        Convert a single Arc_Challenge doc into a ContrastivePair, if possible.
+        Returns None when required fields are missing or malformed.
+        """
+        log = bind(_LOG, doc_id=doc.get("id", "unknown"))
+
+        try:
+            question = str(doc.get("question", "")).strip()
+            choices_dict = doc.get("choices", {})
+            choices = choices_dict["text"]
+            answer = str(doc.get("answerKey", "")).strip()
+            answer_idx = int(ord(answer) - ord('A'))
+
+            if not question or not choices or not (0 <= answer_idx < len(choices)):
+                log.debug(
+                    "Skipping doc due to missing/invalid fields",
+                    extra={"doc": doc},
+                )
+                return None
+
+            correct = choices[answer_idx]
+            incorrect = choices[(answer_idx+1)%len(choices)]
+
+            question = f"{question}"
+            formatted_question = f"Question: {question}\nA. {incorrect}\nB. {correct}"
+
+            metadata = {
+                "label": "arc_easy",
+            }
+
+            return self._build_pair(
+                question=formatted_question,
+                correct=correct,
+                incorrect=incorrect,
+                metadata=metadata,
+            )
+
+        except Exception as exc:
+            log.error("Error extracting pair from doc", exc_info=exc, extra={"doc": doc})
+            return None
+
+    @staticmethod
+    def _build_pair(
+        question: str,
+        correct: str,
+        incorrect: str,
+        metadata: dict[str, Any] | None = None,
+    ) -> ContrastivePair:
+        positive_response = PositiveResponse(model_response=correct)
+        negative_response = NegativeResponse(model_response=incorrect)
+        return ContrastivePair(prompt=question, positive_response=positive_response, negative_response=negative_response, label=metadata.get("label"))
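For orientation, the extractor above pairs the gold choice with the cyclically next choice as the distractor and renders both into a two-option prompt; with this formatting the incorrect choice always appears as option A and the correct one as option B. A standalone walk-through of that mapping on a made-up ARC-style doc (not code from the package):

doc = {
    "question": "What keeps a satellite in orbit?",
    "choices": {"text": ["gravity", "magnetism", "friction"]},
    "answerKey": "A",
}
choices = doc["choices"]["text"]
answer_idx = ord(doc["answerKey"]) - ord("A")         # 0
correct = choices[answer_idx]                         # "gravity"
incorrect = choices[(answer_idx + 1) % len(choices)]  # "magnetism", the cyclically next choice
prompt = f"Question: {doc['question']}\nA. {incorrect}\nB. {correct}"
# positive response: "gravity"; negative response: "magnetism"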
wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py (new file)
@@ -0,0 +1,115 @@
+from __future__ import annotations
+
+from typing import Any, TYPE_CHECKING
+
+from wisent.core.contrastive_pairs.core.pair import ContrastivePair
+from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
+from wisent.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
+from wisent.core.cli_logger import setup_logger, bind
+
+if TYPE_CHECKING:
+    from lm_eval.api.task import ConfigurableTask
+
+
+__all__ = ["Arc_EasyExtractor"]
+_LOG = setup_logger(__name__)
+
+
+class Arc_EasyExtractor(LMEvalBenchmarkExtractor):
+    """Extractor for the Arc_Easy benchmark."""
+
+    def extract_contrastive_pairs(
+        self,
+        lm_eval_task_data: ConfigurableTask,
+        limit: int | None = None,
+    ) -> list[ContrastivePair]:
+        """
+        Build contrastive pairs from Arc_Easy docs.
+
+        Arc_Easy schema:
+            - question
+            - choices: dict,
+            - choices["text"]: list with possible choices strings
+            - answerKey: str
+
+        Args:
+            lm_eval_task_data: lm-eval task instance for Arc_Easy.
+            limit: Optional maximum number of pairs to produce.
+
+        Returns:
+            A list of ContrastivePair objects.
+        """
+        log = bind(_LOG, task=getattr(lm_eval_task_data, "NAME", "unknown"))
+
+        max_items = self._normalize_limit(limit)
+        docs = self.load_docs(lm_eval_task_data, max_items)
+
+        pairs: list[ContrastivePair] = []
+
+        log.info("Extracting contrastive pairs", extra={"doc_count": len(docs)})
+
+        for doc in docs:
+            pair = self._extract_pair_from_doc(doc)
+            if pair is not None:
+                pairs.append(pair)
+                if max_items is not None and len(pairs) >= max_items:
+                    break
+
+        if not pairs:
+            task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
+            log.warning("No valid Arc_Easy pairs extracted", extra={"task": task_name})
+
+        return pairs
+
+    def _extract_pair_from_doc(self, doc: dict[str, Any]) -> ContrastivePair | None:
+        """
+        Convert a single Arc_Easy doc into a ContrastivePair, if possible.
+        Returns None when required fields are missing or malformed.
+        """
+        log = bind(_LOG, doc_id=doc.get("id", "unknown"))
+
+        try:
+            question = str(doc.get("question", "")).strip()
+            choices_dict = doc.get("choices", {})
+            choices = choices_dict["text"]
+            answer = str(doc.get("answerKey", "")).strip()
+            answer_idx = int(ord(answer) - ord('A'))
+
+            if not question or not choices or not (0 <= answer_idx < len(choices)):
+                log.debug(
+                    "Skipping doc due to missing/invalid fields",
+                    extra={"doc": doc},
+                )
+                return None
+
+            correct = choices[answer_idx]
+            incorrect = choices[(answer_idx+1)%len(choices)]
+
+            question = f"{question}"
+            formatted_question = f"Question: {question}\nA. {incorrect}\nB. {correct}"
+
+            metadata = {
+                "label": "arc_easy",
+            }
+
+            return self._build_pair(
+                question=formatted_question,
+                correct=correct,
+                incorrect=incorrect,
+                metadata=metadata,
+            )
+
+        except Exception as exc:
+            log.error("Error extracting pair from doc", exc_info=exc, extra={"doc": doc})
+            return None
+
+    @staticmethod
+    def _build_pair(
+        question: str,
+        correct: str,
+        incorrect: str,
+        metadata: dict[str, Any] | None = None,
+    ) -> ContrastivePair:
+        positive_response = PositiveResponse(model_response=correct)
+        negative_response = NegativeResponse(model_response=incorrect)
+        return ContrastivePair(prompt=question, positive_response=positive_response, negative_response=negative_response, label=metadata.get("label"))
wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py (new file)
@@ -0,0 +1,111 @@
+from __future__ import annotations
+
+import random
+from typing import Any, TYPE_CHECKING
+
+from wisent.core.contrastive_pairs.core.pair import ContrastivePair
+from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
+from wisent.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
+from wisent.core.cli_logger import setup_logger, bind
+
+if TYPE_CHECKING:
+    from lm_eval.api.task import ConfigurableTask
+
+
+__all__ = ["ArithmeticExtractor"]
+_LOG = setup_logger(__name__)
+
+
+class ArithmeticExtractor(LMEvalBenchmarkExtractor):
+    """Extractor for the Arithmetic benchmark."""
+
+    def extract_contrastive_pairs(
+        self,
+        lm_eval_task_data: ConfigurableTask,
+        limit: int | None = None,
+    ) -> list[ContrastivePair]:
+        """
+        Build contrastive pairs from Arithmetic docs.
+
+        Arithmetic schema:
+            - context: str
+            - completion: str
+
+        Args:
+            lm_eval_task_data: lm-eval task instance for Arithmetic.
+            limit: Optional maximum number of pairs to produce.
+
+        Returns:
+            A list of ContrastivePair objects.
+        """
+        log = bind(_LOG, task=getattr(lm_eval_task_data, "NAME", "unknown"))
+
+        max_items = self._normalize_limit(limit)
+        docs = self.load_docs(lm_eval_task_data, max_items)
+
+        pairs: list[ContrastivePair] = []
+
+        log.info("Extracting contrastive pairs", extra={"doc_count": len(docs)})
+
+        for doc in docs:
+            pair = self._extract_pair_from_doc(doc)
+            if pair is not None:
+                pairs.append(pair)
+                if max_items is not None and len(pairs) >= max_items:
+                    break
+
+        if not pairs:
+            task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
+            log.warning("No valid Arithmetic pairs extracted", extra={"task": task_name})
+
+        return pairs
+
+    def _extract_pair_from_doc(self, doc: dict[str, Any]) -> ContrastivePair | None:
+        """
+        Convert a single Arithmetic doc into a ContrastivePair, if possible.
+        Returns None when required fields are missing or malformed.
+        """
+        log = bind(_LOG, doc_id=doc.get("id", "unknown"))
+
+        try:
+            context = str(doc.get("context", "")).strip()
+            completion = str(doc.get("completion", "")).strip()
+
+            if not context or not completion:
+                log.debug(
+                    "Skipping doc due to missing/invalid fields",
+                    extra={"doc": doc},
+                )
+                return None
+
+            correct = completion
+            incorrect_val = float(completion) + 1
+            incorrect = str(int(incorrect_val)) if incorrect_val == int(incorrect_val) else str(incorrect_val)
+
+            formatted_question = f"{context}\nA. {incorrect}\nB. {correct}"
+
+            metadata = {
+                "label": "arithmetic",
+            }
+
+            return self._build_pair(
+                question=formatted_question,
+                correct=correct,
+                incorrect=incorrect,
+                metadata=metadata,
+            )
+
+        except Exception as exc:
+            log.error("Error extracting pair from doc", exc_info=exc, extra={"doc": doc})
+            return None
+
+    @staticmethod
+    def _build_pair(
+        question: str,
+        correct: str,
+        incorrect: str,
+        metadata: dict[str, Any] | None = None,
+    ) -> ContrastivePair:
+        positive_response = PositiveResponse(model_response=correct)
+        negative_response = NegativeResponse(model_response=incorrect)
+        return ContrastivePair(prompt=question, positive_response=positive_response, negative_response=negative_response, label=metadata.get("label"))
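The Arithmetic extractor (and the ASDiv extractor in the next hunk) derives its distractor by adding 1 to the gold number, formatting whole values without a trailing ".0". A quick standalone check of that formatting rule:

completion = "72"
incorrect_val = float(completion) + 1  # 73.0
incorrect = str(int(incorrect_val)) if incorrect_val == int(incorrect_val) else str(incorrect_val)
# incorrect == "73"; a non-integer gold answer such as "4.5" would yield "5.5"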
wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py (new file)
@@ -0,0 +1,119 @@
+from __future__ import annotations
+
+import random
+import re
+from typing import Any, TYPE_CHECKING
+
+from wisent.core.contrastive_pairs.core.pair import ContrastivePair
+from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
+from wisent.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
+from wisent.core.cli_logger import setup_logger, bind
+
+if TYPE_CHECKING:
+    from lm_eval.api.task import ConfigurableTask
+
+
+__all__ = ["ASDivExtractor"]
+_LOG = setup_logger(__name__)
+
+
+class ASDivExtractor(LMEvalBenchmarkExtractor):
+    """Extractor for the ASDiv benchmark."""
+
+    def extract_contrastive_pairs(
+        self,
+        lm_eval_task_data: ConfigurableTask,
+        limit: int | None = None,
+    ) -> list[ContrastivePair]:
+        """
+        Build contrastive pairs from ASDiv docs.
+
+        ASDiv schema:
+            - body: str
+            - question: str
+            - answer: str
+
+        Args:
+            lm_eval_task_data: lm-eval task instance for ASDiv.
+            limit: Optional maximum number of pairs to produce.
+
+        Returns:
+            A list of ContrastivePair objects.
+        """
+        log = bind(_LOG, task=getattr(lm_eval_task_data, "NAME", "unknown"))
+
+        max_items = self._normalize_limit(limit)
+        docs = self.load_docs(lm_eval_task_data, max_items)
+
+        pairs: list[ContrastivePair] = []
+
+        log.info("Extracting contrastive pairs", extra={"doc_count": len(docs)})
+
+        for doc in docs:
+            pair = self._extract_pair_from_doc(doc)
+            if pair is not None:
+                pairs.append(pair)
+                if max_items is not None and len(pairs) >= max_items:
+                    break
+
+        if not pairs:
+            task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
+            log.warning("No valid ASDiv pairs extracted", extra={"task": task_name})
+
+        return pairs
+
+    def _extract_pair_from_doc(self, doc: dict[str, Any]) -> ContrastivePair | None:
+        """
+        Convert a single ASDiv doc into a ContrastivePair, if possible.
+        Returns None when required fields are missing or malformed.
+        """
+        log = bind(_LOG, doc_id=doc.get("id", "unknown"))
+
+        try:
+            body = str(doc.get("body", "")).strip()
+            question = str(doc.get("question", "")).strip()
+            answer = str(doc.get("answer", "")).strip()
+
+            if not question or not answer:
+                log.debug(
+                    "Skipping doc due to missing/invalid fields",
+                    extra={"doc": doc},
+                )
+                return None
+
+            match = re.search(r'-?\d+(?:\.\d+)?', answer.replace(',', ''))
+            if not match:
+                log.debug("Skipping doc due to missing numerical answer", extra={"doc": doc})
+                return None
+            numerical_answer = match.group()
+            correct = numerical_answer
+            incorrect_val = float(numerical_answer) + 1
+            incorrect = str(int(incorrect_val)) if incorrect_val == int(incorrect_val) else str(incorrect_val)
+
+            formatted_question = f"{body}\nQuestion:{question}\nA. {incorrect}\nB. {correct}"
+
+            metadata = {
+                "label": "asdiv",
+            }
+
+            return self._build_pair(
+                question=formatted_question,
+                correct=correct,
+                incorrect=incorrect,
+                metadata=metadata,
+            )
+
+        except Exception as exc:
+            log.error("Error extracting pair from doc", exc_info=exc, extra={"doc": doc})
+            return None
+
+    @staticmethod
+    def _build_pair(
+        question: str,
+        correct: str,
+        incorrect: str,
+        metadata: dict[str, Any] | None = None,
+    ) -> ContrastivePair:
+        positive_response = PositiveResponse(model_response=correct)
+        negative_response = NegativeResponse(model_response=incorrect)
+        return ContrastivePair(prompt=question, positive_response=positive_response, negative_response=negative_response, label=metadata.get("label"))
wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py (new file)
@@ -0,0 +1,115 @@
+from __future__ import annotations
+
+import random
+from typing import Any, TYPE_CHECKING
+
+from wisent.core.contrastive_pairs.core.pair import ContrastivePair
+from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
+from wisent.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
+from wisent.core.cli_logger import setup_logger, bind
+
+if TYPE_CHECKING:
+    from lm_eval.api.task import ConfigurableTask
+
+
+__all__ = ["BoolQExtractor"]
+_LOG = setup_logger(__name__)
+
+
+class BoolQExtractor(LMEvalBenchmarkExtractor):
+    """Extractor for the BoolQ benchmark."""
+
+    def extract_contrastive_pairs(
+        self,
+        lm_eval_task_data: ConfigurableTask,
+        limit: int | None = None,
+        preferred_doc: str | None = None,
+    ) -> list[ContrastivePair]:
+        """
+        Build contrastive pairs from BoolQ docs.
+
+        BoolQ schema:
+            - passage: str
+            - question: str
+            - label: 0 or 1
+
+        Args:
+            lm_eval_task_data: lm-eval task instance for BoolQ.
+            limit: Optional maximum number of pairs to produce.
+            preferred_doc: Preferred document source ("validation", "test", "training", "fewshot")
+
+        Returns:
+            A list of ContrastivePair objects.
+        """
+        log = bind(_LOG, task=getattr(lm_eval_task_data, "NAME", "unknown"))
+
+        max_items = self._normalize_limit(limit)
+        docs = self.load_docs(lm_eval_task_data, max_items, preferred_doc=preferred_doc)
+
+        pairs: list[ContrastivePair] = []
+
+        log.info("Extracting contrastive pairs", extra={"doc_count": len(docs)})
+
+        for doc in docs:
+            pair = self._extract_pair_from_doc(doc)
+            if pair is not None:
+                pairs.append(pair)
+                if max_items is not None and len(pairs) >= max_items:
+                    break
+
+        if not pairs:
+            task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
+            log.warning("No valid BoolQ pairs extracted", extra={"task": task_name})
+
+        return pairs
+
+    def _extract_pair_from_doc(self, doc: dict[str, Any]) -> ContrastivePair | None:
+        """
+        Convert a single BoolQ doc into a ContrastivePair, if possible.
+        Returns None when required fields are missing or malformed.
+        """
+        log = bind(_LOG, doc_id=doc.get("id", "unknown"))
+
+        try:
+            passage = str(doc.get("passage", "")).strip()
+            question = str(doc.get("question", "")).strip()
+            label = doc.get("label")
+
+            if not passage or not question or label not in {0, 1}:
+                log.debug(
+                    "Skipping doc due to missing/invalid fields",
+                    extra={"doc": doc},
+                )
+                return None
+
+            #formatted_question = f"{passage}\nQuestion: {question}?\nAnswer:\nA. Yes\nB. No"
+            formatted_question = f"{passage}\nQuestion: {question}?"
+
+            correct = "Yes" if label == 1 else "No"
+            incorrect = "No" if label == 1 else "Yes"
+
+            metadata = {
+                "label": "boolq",
+            }
+
+            return self._build_pair(
+                question=formatted_question,
+                correct=correct,
+                incorrect=incorrect,
+                metadata=metadata,
+            )
+
+        except Exception as exc:
+            log.error("Error extracting pair from doc", exc_info=exc, extra={"doc": doc})
+            return None
+
+    @staticmethod
+    def _build_pair(
+        question: str,
+        correct: str,
+        incorrect: str,
+        metadata: dict[str, Any] | None = None,
+    ) -> ContrastivePair:
+        positive_response = PositiveResponse(model_response=correct)
+        negative_response = NegativeResponse(model_response=incorrect)
+        return ContrastivePair(prompt=question, positive_response=positive_response, negative_response=negative_response, label=metadata.get("label"))