wisent 0.5.12__py3-none-any.whl → 0.5.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of wisent might be problematic.
- wisent/__init__.py +1 -1
- wisent/core/activations/__init__.py +26 -0
- wisent/core/activations/activations.py +96 -0
- wisent/core/activations/activations_collector.py +71 -20
- wisent/core/activations/prompt_construction_strategy.py +47 -0
- wisent/core/agent/budget.py +2 -2
- wisent/core/agent/device_benchmarks.py +1 -1
- wisent/core/agent/diagnose/classifier_marketplace.py +8 -8
- wisent/core/agent/diagnose/response_diagnostics.py +4 -4
- wisent/core/agent/diagnose/synthetic_classifier_option.py +1 -1
- wisent/core/agent/diagnose/tasks/task_manager.py +3 -3
- wisent/core/agent/diagnose.py +2 -1
- wisent/core/autonomous_agent.py +10 -2
- wisent/core/benchmark_extractors.py +293 -0
- wisent/core/bigcode_integration.py +20 -7
- wisent/core/branding.py +108 -0
- wisent/core/cli/__init__.py +15 -0
- wisent/core/cli/create_steering_vector.py +138 -0
- wisent/core/cli/evaluate_responses.py +715 -0
- wisent/core/cli/generate_pairs.py +128 -0
- wisent/core/cli/generate_pairs_from_task.py +119 -0
- wisent/core/cli/generate_responses.py +129 -0
- wisent/core/cli/generate_vector_from_synthetic.py +149 -0
- wisent/core/cli/generate_vector_from_task.py +147 -0
- wisent/core/cli/get_activations.py +191 -0
- wisent/core/cli/optimize_classification.py +339 -0
- wisent/core/cli/optimize_steering.py +364 -0
- wisent/core/cli/tasks.py +182 -0
- wisent/core/cli_logger.py +22 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +27 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +49 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +118 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +146 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +129 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gsm8k.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/livecodebench.py +367 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +121 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +121 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +110 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/squad2.py +124 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sst2.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +127 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_gen.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +117 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +117 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +127 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +1 -1
- wisent/core/data_loaders/__init__.py +235 -0
- wisent/core/data_loaders/loaders/lm_loader.py +2 -2
- wisent/core/data_loaders/loaders/task_interface_loader.py +300 -0
- wisent/{cli/data_loaders/data_loader_rotator.py → core/data_loaders/rotator.py} +1 -1
- wisent/core/download_full_benchmarks.py +79 -2
- wisent/core/evaluators/benchmark_specific/__init__.py +26 -0
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/evaluator.py +17 -17
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/cpp_sanitizer.py +2 -2
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/java_sanitizer.py +2 -2
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/python_sanitizer.py +2 -2
- wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/__init__.py +3 -0
- wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/provider.py +305 -0
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/core/runtime.py +36 -4
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/entrypoint.py +2 -4
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/recipes.py +1 -1
- wisent/core/evaluators/benchmark_specific/coding/solution_generator.py +258 -0
- wisent/core/evaluators/benchmark_specific/exact_match_evaluator.py +79 -0
- wisent/core/evaluators/benchmark_specific/f1_evaluator.py +101 -0
- wisent/core/evaluators/benchmark_specific/generation_evaluator.py +197 -0
- wisent/core/{log_likelihoods_evaluator.py → evaluators/benchmark_specific/log_likelihoods_evaluator.py} +10 -2
- wisent/core/evaluators/benchmark_specific/perplexity_evaluator.py +140 -0
- wisent/core/evaluators/benchmark_specific/personalization_evaluator.py +250 -0
- wisent/{cli/evaluators/evaluator_rotator.py → core/evaluators/rotator.py} +4 -4
- wisent/core/lm_eval_harness_ground_truth.py +3 -2
- wisent/core/main.py +57 -0
- wisent/core/model_persistence.py +2 -2
- wisent/core/models/wisent_model.py +6 -6
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
- wisent/core/optuna/steering/steering_optimization.py +1 -1
- wisent/core/parser_arguments/__init__.py +10 -0
- wisent/core/parser_arguments/agent_parser.py +110 -0
- wisent/core/parser_arguments/configure_model_parser.py +7 -0
- wisent/core/parser_arguments/create_steering_vector_parser.py +59 -0
- wisent/core/parser_arguments/evaluate_parser.py +40 -0
- wisent/core/parser_arguments/evaluate_responses_parser.py +10 -0
- wisent/core/parser_arguments/full_optimize_parser.py +115 -0
- wisent/core/parser_arguments/generate_pairs_from_task_parser.py +33 -0
- wisent/core/parser_arguments/generate_pairs_parser.py +29 -0
- wisent/core/parser_arguments/generate_responses_parser.py +15 -0
- wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +127 -0
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +127 -0
- wisent/core/parser_arguments/generate_vector_parser.py +90 -0
- wisent/core/parser_arguments/get_activations_parser.py +90 -0
- wisent/core/parser_arguments/main_parser.py +152 -0
- wisent/core/parser_arguments/model_config_parser.py +59 -0
- wisent/core/parser_arguments/monitor_parser.py +17 -0
- wisent/core/parser_arguments/multi_steer_parser.py +47 -0
- wisent/core/parser_arguments/optimize_classification_parser.py +67 -0
- wisent/core/parser_arguments/optimize_sample_size_parser.py +58 -0
- wisent/core/parser_arguments/optimize_steering_parser.py +147 -0
- wisent/core/parser_arguments/synthetic_parser.py +93 -0
- wisent/core/parser_arguments/tasks_parser.py +584 -0
- wisent/core/parser_arguments/test_nonsense_parser.py +26 -0
- wisent/core/parser_arguments/utils.py +111 -0
- wisent/core/prompts/core/prompt_formater.py +3 -3
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +2 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +2 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +2 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +2 -0
- wisent/{cli/steering_methods/steering_rotator.py → core/steering_methods/rotator.py} +4 -4
- wisent/core/steering_optimizer.py +45 -21
- wisent/{synthetic → core/synthetic}/cleaners/deduper_cleaner.py +3 -3
- wisent/{synthetic → core/synthetic}/cleaners/methods/base_dedupers.py +2 -2
- wisent/{synthetic → core/synthetic}/cleaners/methods/base_refusalers.py +1 -1
- wisent/{synthetic → core/synthetic}/cleaners/pairs_cleaner.py +5 -5
- wisent/{synthetic → core/synthetic}/cleaners/refusaler_cleaner.py +4 -4
- wisent/{synthetic → core/synthetic}/db_instructions/mini_dp.py +1 -1
- wisent/{synthetic → core/synthetic}/generators/diversities/methods/fast_diversity.py +1 -1
- wisent/{synthetic → core/synthetic}/generators/pairs_generator.py +38 -12
- wisent/core/tasks/livecodebench_task.py +4 -103
- wisent/core/timing_calibration.py +1 -1
- {wisent-0.5.12.dist-info → wisent-0.5.13.dist-info}/METADATA +3 -3
- wisent-0.5.13.dist-info/RECORD +294 -0
- wisent-0.5.13.dist-info/entry_points.txt +2 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +0 -53
- wisent/classifiers/core/atoms.py +0 -747
- wisent/classifiers/models/logistic.py +0 -29
- wisent/classifiers/models/mlp.py +0 -47
- wisent/cli/classifiers/classifier_rotator.py +0 -137
- wisent/cli/cli_logger.py +0 -142
- wisent/cli/wisent_cli/commands/help_cmd.py +0 -52
- wisent/cli/wisent_cli/commands/listing.py +0 -154
- wisent/cli/wisent_cli/commands/train_cmd.py +0 -322
- wisent/cli/wisent_cli/main.py +0 -93
- wisent/cli/wisent_cli/shell.py +0 -80
- wisent/cli/wisent_cli/ui.py +0 -69
- wisent/cli/wisent_cli/util/aggregations.py +0 -43
- wisent/cli/wisent_cli/util/parsing.py +0 -126
- wisent/cli/wisent_cli/version.py +0 -4
- wisent/opti/methods/__init__.py +0 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent-0.5.12.dist-info/RECORD +0 -220
- /wisent/{benchmarks → core/evaluators/benchmark_specific/coding}/__init__.py +0 -0
- /wisent/{benchmarks/coding → core/evaluators/benchmark_specific/coding/metrics}/__init__.py +0 -0
- /wisent/{benchmarks/coding/metrics → core/evaluators/benchmark_specific/coding/metrics/core}/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/core/atoms.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/passk.py +0 -0
- /wisent/{benchmarks/coding/metrics/core → core/evaluators/benchmark_specific/coding/output_sanitizer}/__init__.py +0 -0
- /wisent/{benchmarks/coding/output_sanitizer → core/evaluators/benchmark_specific/coding/output_sanitizer/core}/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/core/atoms.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/utils.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/providers/__init__.py +0 -0
- /wisent/{benchmarks/coding/output_sanitizer → core/evaluators/benchmark_specific/coding/providers}/core/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/providers/core/atoms.py +0 -0
- /wisent/{benchmarks/coding/providers/core → core/evaluators/benchmark_specific/coding/safe_docker}/__init__.py +0 -0
- /wisent/{benchmarks/coding/providers/livecodebench → core/evaluators/benchmark_specific/coding/safe_docker/core}/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/core/atoms.py +0 -0
- /wisent/{benchmarks/coding/safe_docker → core/opti}/__init__.py +0 -0
- /wisent/{benchmarks/coding/safe_docker → core/opti}/core/__init__.py +0 -0
- /wisent/{opti → core/opti}/core/atoms.py +0 -0
- /wisent/{classifiers → core/opti/methods}/__init__.py +0 -0
- /wisent/{opti → core/opti}/methods/opti_classificator.py +0 -0
- /wisent/{opti → core/opti}/methods/opti_steering.py +0 -0
- /wisent/{classifiers/core → core/synthetic}/__init__.py +0 -0
- /wisent/{classifiers/models → core/synthetic/cleaners}/__init__.py +0 -0
- /wisent/{cli → core/synthetic/cleaners/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/cleaners/core/atoms.py +0 -0
- /wisent/{cli/classifiers → core/synthetic/cleaners/methods}/__init__.py +0 -0
- /wisent/{cli/data_loaders → core/synthetic/cleaners/methods/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/cleaners/methods/core/atoms.py +0 -0
- /wisent/{cli/evaluators → core/synthetic/db_instructions}/__init__.py +0 -0
- /wisent/{cli/steering_methods → core/synthetic/db_instructions/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/db_instructions/core/atoms.py +0 -0
- /wisent/{cli/wisent_cli → core/synthetic/generators}/__init__.py +0 -0
- /wisent/{cli/wisent_cli/commands → core/synthetic/generators/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/generators/core/atoms.py +0 -0
- /wisent/{cli/wisent_cli/util → core/synthetic/generators/diversities}/__init__.py +0 -0
- /wisent/{opti → core/synthetic/generators/diversities/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/generators/diversities/core/core.py +0 -0
- /wisent/{opti/core → core/synthetic/generators/diversities/methods}/__init__.py +0 -0
- {wisent-0.5.12.dist-info → wisent-0.5.13.dist-info}/WHEEL +0 -0
- {wisent-0.5.12.dist-info → wisent-0.5.13.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.5.12.dist-info → wisent-0.5.13.dist-info}/top_level.txt +0 -0
wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py
@@ -0,0 +1,115 @@
+from __future__ import annotations
+
+from typing import Any, TYPE_CHECKING
+
+from wisent.core.contrastive_pairs.core.pair import ContrastivePair
+from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
+from wisent.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
+from wisent.core.cli_logger import setup_logger, bind
+
+if TYPE_CHECKING:
+    from lm_eval.api.task import ConfigurableTask
+
+
+__all__ = ["LogiQAExtractor"]
+_LOG = setup_logger(__name__)
+
+
+class LogiQAExtractor(LMEvalBenchmarkExtractor):
+    """Extractor for the LogiQA benchmark."""
+
+    def extract_contrastive_pairs(
+        self,
+        lm_eval_task_data: ConfigurableTask,
+        limit: int | None = None,
+    ) -> list[ContrastivePair]:
+        """
+        Build contrastive pairs from LogiQA docs.
+
+        LogiQA schema:
+            - context: str
+            - question: str
+            - options: list
+            - label: "a", "b", "c", "d"
+
+        Args:
+            lm_eval_task_data: lm-eval task instance for LogiQA.
+            limit: Optional maximum number of pairs to produce.
+
+        Returns:
+            A list of ContrastivePair objects.
+        """
+        log = bind(_LOG, task=getattr(lm_eval_task_data, "NAME", "unknown"))
+
+        max_items = self._normalize_limit(limit)
+        docs = self.load_docs(lm_eval_task_data, max_items)
+
+        pairs: list[ContrastivePair] = []
+
+        log.info("Extracting contrastive pairs", extra={"doc_count": len(docs)})
+
+        for doc in docs:
+            pair = self._extract_pair_from_doc(doc)
+            if pair is not None:
+                pairs.append(pair)
+            if max_items is not None and len(pairs) >= max_items:
+                break
+
+        if not pairs:
+            task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
+            log.warning("No valid LogiQA pairs extracted", extra={"task": task_name})
+
+        return pairs
+
+    def _extract_pair_from_doc(self, doc: dict[str, Any]) -> ContrastivePair | None:
+        """
+        Convert a single LogiQA doc into a ContrastivePair, if possible.
+        Returns None when required fields are missing or malformed.
+        """
+        log = bind(_LOG, doc_id=doc.get("id", "unknown"))
+
+        try:
+            context = str(doc.get("context", "")).strip()
+            question = str(doc.get("question", "")).strip()
+            options = doc.get("options", [])
+            label = str(doc.get("label", "")).strip()
+            label_idx = int(ord(label) - ord("a"))
+
+            if not context or not question or not (0 <= label_idx < len(options)):
+                log.debug(
+                    "Skipping doc due to missing/invalid fields",
+                    extra={"doc": doc},
+                )
+                return None
+
+            correct = options[label_idx]
+            incorrect = options[(label_idx + 1) % len(options)]
+
+            question = f"{question}"
+            formatted_question = f"Passage: {context}\nQuestion: {question}\nA. {incorrect}\nB. {correct}"
+
+            metadata = {
+                "label": "logiqa",
+            }
+
+            return self._build_pair(
+                question=formatted_question,
+                correct=correct,
+                incorrect=incorrect,
+                metadata=metadata,
+            )
+
+        except Exception as exc:
+            log.error("Error extracting pair from doc", exc_info=exc, extra={"doc": doc})
+            return None
+
+    @staticmethod
+    def _build_pair(
+        question: str,
+        correct: str,
+        incorrect: str,
+        metadata: dict[str, Any] | None = None,
+    ) -> ContrastivePair:
+        positive_response = PositiveResponse(model_response=correct)
+        negative_response = NegativeResponse(model_response=incorrect)
+        return ContrastivePair(prompt=question, positive_response=positive_response, negative_response=negative_response, label=metadata.get("label"))
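All six new extractors follow the same shape: load docs from the lm-eval task, convert each doc into a ContrastivePair via `_extract_pair_from_doc`, and pair the gold option with a distractor. A minimal sketch of exercising that per-doc conversion on a handcrafted LogiQA-style doc; the sample doc is invented for illustration, and it assumes wisent 0.5.13 is installed and that `LogiQAExtractor` takes no constructor arguments (the diff does not show its `__init__`):

```python
# Hypothetical usage sketch -- the doc below is invented; _extract_pair_from_doc
# is a private helper, so this is a demonstration of the conversion, not a
# documented API.
from wisent.core.contrastive_pairs.lm_eval_pairs.lm_task_extractors.logiqa import LogiQAExtractor

doc = {
    "id": "demo-0",
    "context": "All metals conduct electricity. Copper is a metal.",
    "question": "Which conclusion follows?",
    "options": [
        "Copper conducts electricity.",  # gold option ("a")
        "Copper is an insulator.",
        "Metals are rare.",
        "Electricity is a metal.",
    ],
    "label": "a",  # letter label, mapped to an index via ord(label) - ord("a")
}

extractor = LogiQAExtractor()  # assumption: no constructor arguments required
pair = extractor._extract_pair_from_doc(doc)
if pair is not None:
    # Attribute names assumed to mirror the keyword arguments seen in the diff.
    print(pair.prompt)                            # "Passage: ...\nQuestion: ...\nA. ...\nB. ..."
    print(pair.positive_response.model_response)  # gold option
    print(pair.negative_response.model_response)  # distractor (next option, wrapping)
```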
wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py
@@ -0,0 +1,114 @@
+from __future__ import annotations
+
+from typing import Any, TYPE_CHECKING
+
+from wisent.core.contrastive_pairs.core.pair import ContrastivePair
+from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
+from wisent.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
+from wisent.core.cli_logger import setup_logger, bind
+
+if TYPE_CHECKING:
+    from lm_eval.api.task import ConfigurableTask
+
+
+__all__ = ["LogiQA2Extractor"]
+_LOG = setup_logger(__name__)
+
+
+class LogiQA2Extractor(LMEvalBenchmarkExtractor):
+    """Extractor for the LogiQA2 benchmark."""
+
+    def extract_contrastive_pairs(
+        self,
+        lm_eval_task_data: ConfigurableTask,
+        limit: int | None = None,
+    ) -> list[ContrastivePair]:
+        """
+        Build contrastive pairs from LogiQA2 docs.
+
+        LogiQA2 schema:
+            - text: str
+            - question: str
+            - options: list
+            - label: 0, 1, 2, 3
+
+        Args:
+            lm_eval_task_data: lm-eval task instance for LogiQA2.
+            limit: Optional maximum number of pairs to produce.
+
+        Returns:
+            A list of ContrastivePair objects.
+        """
+        log = bind(_LOG, task=getattr(lm_eval_task_data, "NAME", "unknown"))
+
+        max_items = self._normalize_limit(limit)
+        docs = self.load_docs(lm_eval_task_data, max_items)
+
+        pairs: list[ContrastivePair] = []
+
+        log.info("Extracting contrastive pairs", extra={"doc_count": len(docs)})
+
+        for doc in docs:
+            pair = self._extract_pair_from_doc(doc)
+            if pair is not None:
+                pairs.append(pair)
+            if max_items is not None and len(pairs) >= max_items:
+                break
+
+        if not pairs:
+            task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
+            log.warning("No valid LogiQA2 pairs extracted", extra={"task": task_name})
+
+        return pairs
+
+    def _extract_pair_from_doc(self, doc: dict[str, Any]) -> ContrastivePair | None:
+        """
+        Convert a single LogiQA2 doc into a ContrastivePair, if possible.
+        Returns None when required fields are missing or malformed.
+        """
+        log = bind(_LOG, doc_id=doc.get("id", "unknown"))
+
+        try:
+            text = str(doc.get("text", "")).strip()
+            question = str(doc.get("question", "")).strip()
+            options = doc.get("options", [])
+            answer = doc.get("answer")
+
+            if not text or not question or not (0 <= answer < len(options)):
+                log.debug(
+                    "Skipping doc due to missing/invalid fields",
+                    extra={"doc": doc},
+                )
+                return None
+
+            correct = options[answer]
+            incorrect = options[(answer + 1) % len(options)]
+
+            question = f"{question}"
+            formatted_question = f"Passage: {text}\nQuestion: {question}\nA. {incorrect}\nB. {correct}"
+
+            metadata = {
+                "label": "logiqa2",
+            }
+
+            return self._build_pair(
+                question=formatted_question,
+                correct=correct,
+                incorrect=incorrect,
+                metadata=metadata,
+            )
+
+        except Exception as exc:
+            log.error("Error extracting pair from doc", exc_info=exc, extra={"doc": doc})
+            return None
+
+    @staticmethod
+    def _build_pair(
+        question: str,
+        correct: str,
+        incorrect: str,
+        metadata: dict[str, Any] | None = None,
+    ) -> ContrastivePair:
+        positive_response = PositiveResponse(model_response=correct)
+        negative_response = NegativeResponse(model_response=incorrect)
+        return ContrastivePair(prompt=question, positive_response=positive_response, negative_response=negative_response, label=metadata.get("label"))
wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py
@@ -0,0 +1,113 @@
+from __future__ import annotations
+
+from typing import Any, TYPE_CHECKING
+
+from wisent.core.contrastive_pairs.core.pair import ContrastivePair
+from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
+from wisent.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
+from wisent.core.cli_logger import setup_logger, bind
+
+if TYPE_CHECKING:
+    from lm_eval.api.task import ConfigurableTask
+
+
+__all__ = ["MCTACOExtractor"]
+_LOG = setup_logger(__name__)
+
+
+class MCTACOExtractor(LMEvalBenchmarkExtractor):
+    """Extractor for the MC-TACO benchmark."""
+
+    def extract_contrastive_pairs(
+        self,
+        lm_eval_task_data: ConfigurableTask,
+        limit: int | None = None,
+    ) -> list[ContrastivePair]:
+        """
+        Build contrastive pairs from MC-TACO docs.
+
+        MC-TACO schema:
+            - sentence: str
+            - question: str
+            - answer: str
+            - label: 0 or 1
+
+        Args:
+            lm_eval_task_data: lm-eval task instance for MC-TACO.
+            limit: Optional maximum number of pairs to produce.
+
+        Returns:
+            A list of ContrastivePair objects.
+        """
+        log = bind(_LOG, task=getattr(lm_eval_task_data, "NAME", "unknown"))
+
+        max_items = self._normalize_limit(limit)
+        docs = self.load_docs(lm_eval_task_data, max_items)
+
+        pairs: list[ContrastivePair] = []
+
+        log.info("Extracting contrastive pairs", extra={"doc_count": len(docs)})
+
+        for doc in docs:
+            pair = self._extract_pair_from_doc(doc)
+            if pair is not None:
+                pairs.append(pair)
+            if max_items is not None and len(pairs) >= max_items:
+                break
+
+        if not pairs:
+            task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
+            log.warning("No valid MC-TACO pairs extracted", extra={"task": task_name})
+
+        return pairs
+
+    def _extract_pair_from_doc(self, doc: dict[str, Any]) -> ContrastivePair | None:
+        """
+        Convert a single MC-TACO doc into a ContrastivePair, if possible.
+        Returns None when required fields are missing or malformed.
+        """
+        log = bind(_LOG, doc_id=doc.get("id", "unknown"))
+
+        try:
+            sentence = str(doc.get("sentence", "")).strip()
+            question = str(doc.get("question", "")).strip()
+            answer = str(doc.get("answer", "")).strip()
+            label = doc.get("label")
+
+            if not sentence or not question or not answer or label not in {0, 1}:
+                log.debug(
+                    "Skipping doc due to missing/invalid fields",
+                    extra={"doc": doc},
+                )
+                return None
+
+            formatted_question = f"{sentence}\nQuestion: {question}\nAnswer: {answer}\nPlausible:\nA. Yes\nB. No"
+
+            correct = "Yes" if label == 1 else "No"
+            incorrect = "No" if label == 1 else "Yes"
+
+            metadata = {
+                "label": "mc-taco",
+            }
+
+            return self._build_pair(
+                question=formatted_question,
+                correct=correct,
+                incorrect=incorrect,
+                metadata=metadata,
+            )
+
+        except Exception as exc:
+            log.error("Error extracting pair from doc", exc_info=exc, extra={"doc": doc})
+            return None
+
+    @staticmethod
+    def _build_pair(
+        question: str,
+        correct: str,
+        incorrect: str,
+        metadata: dict[str, Any] | None = None,
+    ) -> ContrastivePair:
+        positive_response = PositiveResponse(model_response=correct)
+        negative_response = NegativeResponse(model_response=incorrect)
+        return ContrastivePair(prompt=question, positive_response=positive_response, negative_response=negative_response, label=metadata.get("label"))
wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py
@@ -0,0 +1,112 @@
+from __future__ import annotations
+
+from typing import Any, TYPE_CHECKING
+
+from wisent.core.contrastive_pairs.core.pair import ContrastivePair
+from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
+from wisent.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
+from wisent.core.cli_logger import setup_logger, bind
+
+if TYPE_CHECKING:
+    from lm_eval.api.task import ConfigurableTask
+
+
+__all__ = ["MedQAExtractor"]
+_LOG = setup_logger(__name__)
+
+
+class MedQAExtractor(LMEvalBenchmarkExtractor):
+    """Extractor for the MedQA benchmark."""
+
+    def extract_contrastive_pairs(
+        self,
+        lm_eval_task_data: ConfigurableTask,
+        limit: int | None = None,
+    ) -> list[ContrastivePair]:
+        """
+        Build contrastive pairs from MedQA docs.
+
+        MedQA schema:
+            - sent1: str
+            - ending0, ending1, ending2, ending3: str
+            - label: 0 or 1 or 2 or 3
+
+        Args:
+            lm_eval_task_data: lm-eval task instance for MedQA.
+            limit: Optional maximum number of pairs to produce.
+
+        Returns:
+            A list of ContrastivePair objects.
+        """
+        log = bind(_LOG, task=getattr(lm_eval_task_data, "NAME", "unknown"))
+
+        max_items = self._normalize_limit(limit)
+        docs = self.load_docs(lm_eval_task_data, max_items)
+
+        pairs: list[ContrastivePair] = []
+
+        log.info("Extracting contrastive pairs", extra={"doc_count": len(docs)})
+
+        for doc in docs:
+            pair = self._extract_pair_from_doc(doc)
+            if pair is not None:
+                pairs.append(pair)
+            if max_items is not None and len(pairs) >= max_items:
+                break
+
+        if not pairs:
+            task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
+            log.warning("No valid MedQA pairs extracted", extra={"task": task_name})
+
+        return pairs
+
+    def _extract_pair_from_doc(self, doc: dict[str, Any]) -> ContrastivePair | None:
+        """
+        Convert a single MedQA doc into a ContrastivePair, if possible.
+        Returns None when required fields are missing or malformed.
+        """
+        log = bind(_LOG, doc_id=doc.get("id", "unknown"))
+
+        try:
+            sent1 = str(doc.get("sent1", "")).strip()
+            endings = [str(doc.get("ending0", "")).strip(), str(doc.get("ending1", "")).strip(), str(doc.get("ending2", "")).strip(), str(doc.get("ending3", "")).strip()]
+            label = doc.get("label")
+
+            if not sent1 or not endings or label not in {0, 1, 2, 3}:
+                log.debug(
+                    "Skipping doc due to missing/invalid fields",
+                    extra={"doc": doc},
+                )
+                return None
+
+            labels = {0: "entailment", 1: "neutral", 2: "contradiction"}
+            correct = endings[label]
+            incorrect = endings[(label + 1) % 3]
+
+            formatted_question = f"Question:{sent1}\nA. {incorrect}\nB. {correct}"
+
+            metadata = {
+                "label": "medqa",
+            }
+
+            return self._build_pair(
+                question=formatted_question,
+                correct=correct,
+                incorrect=incorrect,
+                metadata=metadata,
+            )
+
+        except Exception as exc:
+            log.error("Error extracting pair from doc", exc_info=exc, extra={"doc": doc})
+            return None
+
+    @staticmethod
+    def _build_pair(
+        question: str,
+        correct: str,
+        incorrect: str,
+        metadata: dict[str, Any] | None = None,
+    ) -> ContrastivePair:
+        positive_response = PositiveResponse(model_response=correct)
+        negative_response = NegativeResponse(model_response=incorrect)
+        return ContrastivePair(prompt=question, positive_response=positive_response, negative_response=negative_response, label=metadata.get("label"))
wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py
@@ -0,0 +1,111 @@
+from __future__ import annotations
+
+from typing import Any, TYPE_CHECKING
+
+from wisent.core.contrastive_pairs.core.pair import ContrastivePair
+from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
+from wisent.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
+from wisent.core.cli_logger import setup_logger, bind
+
+if TYPE_CHECKING:
+    from lm_eval.api.task import ConfigurableTask
+
+
+__all__ = ["MRPCExtractor"]
+_LOG = setup_logger(__name__)
+
+
+class MRPCExtractor(LMEvalBenchmarkExtractor):
+    """Extractor for the MRPC benchmark."""
+
+    def extract_contrastive_pairs(
+        self,
+        lm_eval_task_data: ConfigurableTask,
+        limit: int | None = None,
+    ) -> list[ContrastivePair]:
+        """
+        Build contrastive pairs from MRPC docs.
+
+        MRPC schema:
+            - sentence1: str
+            - sentence2: str
+            - label: 0 or 1
+
+        Args:
+            lm_eval_task_data: lm-eval task instance for MRPC.
+            limit: Optional maximum number of pairs to produce.
+
+        Returns:
+            A list of ContrastivePair objects.
+        """
+        log = bind(_LOG, task=getattr(lm_eval_task_data, "NAME", "unknown"))
+
+        max_items = self._normalize_limit(limit)
+        docs = self.load_docs(lm_eval_task_data, max_items)
+
+        pairs: list[ContrastivePair] = []
+
+        log.info("Extracting contrastive pairs", extra={"doc_count": len(docs)})
+
+        for doc in docs:
+            pair = self._extract_pair_from_doc(doc)
+            if pair is not None:
+                pairs.append(pair)
+            if max_items is not None and len(pairs) >= max_items:
+                break
+
+        if not pairs:
+            task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
+            log.warning("No valid MRPC pairs extracted", extra={"task": task_name})
+
+        return pairs
+
+    def _extract_pair_from_doc(self, doc: dict[str, Any]) -> ContrastivePair | None:
+        """
+        Convert a single MRPC doc into a ContrastivePair, if possible.
+        Returns None when required fields are missing or malformed.
+        """
+        log = bind(_LOG, doc_id=doc.get("id", "unknown"))
+
+        try:
+            sentence1 = str(doc.get("sentence1", "")).strip()
+            sentence2 = str(doc.get("sentence2", "")).strip()
+            label = doc.get("label")
+
+            if not sentence1 or not sentence2 or label not in {0, 1}:
+                log.debug(
+                    "Skipping doc due to missing/invalid fields",
+                    extra={"doc": doc},
+                )
+                return None
+
+            formatted_question = f"Sentence 1: {sentence1}\nSentence 2: {sentence2}. Do both sequences mean the same thing?\nAnswer:\nA. Yes\nB. No"
+
+            correct = "Yes" if label == 1 else "No"
+            incorrect = "No" if label == 1 else "Yes"
+
+            metadata = {
+                "label": "mrpc",
+            }
+
+            return self._build_pair(
+                question=formatted_question,
+                correct=correct,
+                incorrect=incorrect,
+                metadata=metadata,
+            )
+
+        except Exception as exc:
+            log.error("Error extracting pair from doc", exc_info=exc, extra={"doc": doc})
+            return None
+
+    @staticmethod
+    def _build_pair(
+        question: str,
+        correct: str,
+        incorrect: str,
+        metadata: dict[str, Any] | None = None,
+    ) -> ContrastivePair:
+        positive_response = PositiveResponse(model_response=correct)
+        negative_response = NegativeResponse(model_response=incorrect)
+        return ContrastivePair(prompt=question, positive_response=positive_response, negative_response=negative_response, label=metadata.get("label"))
wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py
@@ -0,0 +1,114 @@
+from __future__ import annotations
+
+import random
+from typing import Any, TYPE_CHECKING
+
+from wisent.core.contrastive_pairs.core.pair import ContrastivePair
+from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
+from wisent.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
+from wisent.core.cli_logger import setup_logger, bind
+
+if TYPE_CHECKING:
+    from lm_eval.api.task import ConfigurableTask
+
+
+__all__ = ["MultiRCExtractor"]
+_LOG = setup_logger(__name__)
+
+
+class MultiRCExtractor(LMEvalBenchmarkExtractor):
+    """Extractor for the MultiRC benchmark."""
+
+    def extract_contrastive_pairs(
+        self,
+        lm_eval_task_data: ConfigurableTask,
+        limit: int | None = None,
+    ) -> list[ContrastivePair]:
+        """
+        Build contrastive pairs from MultiRC docs.
+
+        MultiRC schema:
+            - paragraph: str
+            - question: str
+            - answer: dict
+            - label: 0 or 1
+
+        Args:
+            lm_eval_task_data: lm-eval task instance for MultiRC.
+            limit: Optional maximum number of pairs to produce.
+
+        Returns:
+            A list of ContrastivePair objects.
+        """
+        log = bind(_LOG, task=getattr(lm_eval_task_data, "NAME", "unknown"))
+
+        max_items = self._normalize_limit(limit)
+        docs = self.load_docs(lm_eval_task_data, max_items)
+
+        pairs: list[ContrastivePair] = []
+
+        log.info("Extracting contrastive pairs", extra={"doc_count": len(docs)})
+
+        for doc in docs:
+            pair = self._extract_pair_from_doc(doc)
+            if pair is not None:
+                pairs.append(pair)
+            if max_items is not None and len(pairs) >= max_items:
+                break
+
+        if not pairs:
+            task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
+            log.warning("No valid MultiRC pairs extracted", extra={"task": task_name})
+
+        return pairs
+
+    def _extract_pair_from_doc(self, doc: dict[str, Any]) -> ContrastivePair | None:
+        """
+        Convert a single MultiRC doc into a ContrastivePair, if possible.
+        Returns None when required fields are missing or malformed.
+        """
+        log = bind(_LOG, doc_id=doc.get("id", "unknown"))
+
+        try:
+            paragraph = str(doc.get("paragraph", "")).strip()
+            question = str(doc.get("question")).strip()
+            answer = str(doc.get("answer")).strip()
+            label = doc.get("label")
+
+            if not paragraph or not question or not answer or label not in {0, 1}:
+                log.debug(
+                    "Skipping doc due to missing/invalid fields",
+                    extra={"doc": doc},
+                )
+                return None
+
+            formatted_question = f"{paragraph}\nQuestion: {question}\nAnswer: {answer}\nIs this answer correct?\nA. Yes\nB. No"
+
+            correct = "Yes" if label == 1 else "No"
+            incorrect = "No" if label == 1 else "Yes"
+
+            metadata = {
+                "label": "multirc",
+            }
+
+            return self._build_pair(
+                question=formatted_question,
+                correct=correct,
+                incorrect=incorrect,
+                metadata=metadata,
+            )
+
+        except Exception as exc:
+            log.error("Error extracting pair from doc", exc_info=exc, extra={"doc": doc})
+            return None
+
+    @staticmethod
+    def _build_pair(
+        question: str,
+        correct: str,
+        incorrect: str,
+        metadata: dict[str, Any] | None = None,
+    ) -> ContrastivePair:
+        positive_response = PositiveResponse(model_response=correct)
+        negative_response = NegativeResponse(model_response=incorrect)
+        return ContrastivePair(prompt=question, positive_response=positive_response, negative_response=negative_response, label=metadata.get("label"))
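One detail shared by the multiple-choice extractors above (LogiQA, LogiQA2, MedQA): the distractor is not sampled, it is the option immediately after the gold answer, wrapping around the option list. A standalone sketch of that selection rule, with no wisent imports required:

```python
# Distractor selection rule used by the multiple-choice extractors above:
# the option immediately after the gold one, wrapping around the list.
def pick_distractor(options: list[str], gold_idx: int) -> str:
    return options[(gold_idx + 1) % len(options)]

assert pick_distractor(["w", "x", "y", "z"], 3) == "w"  # wraps to the first option
```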