wisent 0.5.12__py3-none-any.whl → 0.5.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of wisent might be problematic.
- wisent/__init__.py +1 -1
- wisent/core/activations/__init__.py +26 -0
- wisent/core/activations/activations.py +96 -0
- wisent/core/activations/activations_collector.py +71 -20
- wisent/core/activations/prompt_construction_strategy.py +47 -0
- wisent/core/agent/budget.py +2 -2
- wisent/core/agent/device_benchmarks.py +1 -1
- wisent/core/agent/diagnose/classifier_marketplace.py +8 -8
- wisent/core/agent/diagnose/response_diagnostics.py +4 -4
- wisent/core/agent/diagnose/synthetic_classifier_option.py +1 -1
- wisent/core/agent/diagnose/tasks/task_manager.py +3 -3
- wisent/core/agent/diagnose.py +2 -1
- wisent/core/autonomous_agent.py +10 -2
- wisent/core/benchmark_extractors.py +293 -0
- wisent/core/bigcode_integration.py +20 -7
- wisent/core/branding.py +108 -0
- wisent/core/cli/__init__.py +15 -0
- wisent/core/cli/create_steering_vector.py +138 -0
- wisent/core/cli/evaluate_responses.py +715 -0
- wisent/core/cli/generate_pairs.py +128 -0
- wisent/core/cli/generate_pairs_from_task.py +119 -0
- wisent/core/cli/generate_responses.py +129 -0
- wisent/core/cli/generate_vector_from_synthetic.py +149 -0
- wisent/core/cli/generate_vector_from_task.py +147 -0
- wisent/core/cli/get_activations.py +191 -0
- wisent/core/cli/optimize_classification.py +339 -0
- wisent/core/cli/optimize_steering.py +364 -0
- wisent/core/cli/tasks.py +182 -0
- wisent/core/cli_logger.py +22 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +27 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +49 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +118 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +146 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +129 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gsm8k.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/livecodebench.py +367 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +121 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +121 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +110 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/squad2.py +124 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sst2.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +127 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_gen.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +117 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +117 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +127 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +1 -1
- wisent/core/data_loaders/__init__.py +235 -0
- wisent/core/data_loaders/loaders/lm_loader.py +2 -2
- wisent/core/data_loaders/loaders/task_interface_loader.py +300 -0
- wisent/{cli/data_loaders/data_loader_rotator.py → core/data_loaders/rotator.py} +1 -1
- wisent/core/download_full_benchmarks.py +79 -2
- wisent/core/evaluators/benchmark_specific/__init__.py +26 -0
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/evaluator.py +17 -17
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/cpp_sanitizer.py +2 -2
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/java_sanitizer.py +2 -2
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/python_sanitizer.py +2 -2
- wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/__init__.py +3 -0
- wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/provider.py +305 -0
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/core/runtime.py +36 -4
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/entrypoint.py +2 -4
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/recipes.py +1 -1
- wisent/core/evaluators/benchmark_specific/coding/solution_generator.py +258 -0
- wisent/core/evaluators/benchmark_specific/exact_match_evaluator.py +79 -0
- wisent/core/evaluators/benchmark_specific/f1_evaluator.py +101 -0
- wisent/core/evaluators/benchmark_specific/generation_evaluator.py +197 -0
- wisent/core/{log_likelihoods_evaluator.py → evaluators/benchmark_specific/log_likelihoods_evaluator.py} +10 -2
- wisent/core/evaluators/benchmark_specific/perplexity_evaluator.py +140 -0
- wisent/core/evaluators/benchmark_specific/personalization_evaluator.py +250 -0
- wisent/{cli/evaluators/evaluator_rotator.py → core/evaluators/rotator.py} +4 -4
- wisent/core/lm_eval_harness_ground_truth.py +3 -2
- wisent/core/main.py +57 -0
- wisent/core/model_persistence.py +2 -2
- wisent/core/models/wisent_model.py +6 -6
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
- wisent/core/optuna/steering/steering_optimization.py +1 -1
- wisent/core/parser_arguments/__init__.py +10 -0
- wisent/core/parser_arguments/agent_parser.py +110 -0
- wisent/core/parser_arguments/configure_model_parser.py +7 -0
- wisent/core/parser_arguments/create_steering_vector_parser.py +59 -0
- wisent/core/parser_arguments/evaluate_parser.py +40 -0
- wisent/core/parser_arguments/evaluate_responses_parser.py +10 -0
- wisent/core/parser_arguments/full_optimize_parser.py +115 -0
- wisent/core/parser_arguments/generate_pairs_from_task_parser.py +33 -0
- wisent/core/parser_arguments/generate_pairs_parser.py +29 -0
- wisent/core/parser_arguments/generate_responses_parser.py +15 -0
- wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +127 -0
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +127 -0
- wisent/core/parser_arguments/generate_vector_parser.py +90 -0
- wisent/core/parser_arguments/get_activations_parser.py +90 -0
- wisent/core/parser_arguments/main_parser.py +152 -0
- wisent/core/parser_arguments/model_config_parser.py +59 -0
- wisent/core/parser_arguments/monitor_parser.py +17 -0
- wisent/core/parser_arguments/multi_steer_parser.py +47 -0
- wisent/core/parser_arguments/optimize_classification_parser.py +67 -0
- wisent/core/parser_arguments/optimize_sample_size_parser.py +58 -0
- wisent/core/parser_arguments/optimize_steering_parser.py +147 -0
- wisent/core/parser_arguments/synthetic_parser.py +93 -0
- wisent/core/parser_arguments/tasks_parser.py +584 -0
- wisent/core/parser_arguments/test_nonsense_parser.py +26 -0
- wisent/core/parser_arguments/utils.py +111 -0
- wisent/core/prompts/core/prompt_formater.py +3 -3
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +2 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +2 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +2 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +2 -0
- wisent/{cli/steering_methods/steering_rotator.py → core/steering_methods/rotator.py} +4 -4
- wisent/core/steering_optimizer.py +45 -21
- wisent/{synthetic → core/synthetic}/cleaners/deduper_cleaner.py +3 -3
- wisent/{synthetic → core/synthetic}/cleaners/methods/base_dedupers.py +2 -2
- wisent/{synthetic → core/synthetic}/cleaners/methods/base_refusalers.py +1 -1
- wisent/{synthetic → core/synthetic}/cleaners/pairs_cleaner.py +5 -5
- wisent/{synthetic → core/synthetic}/cleaners/refusaler_cleaner.py +4 -4
- wisent/{synthetic → core/synthetic}/db_instructions/mini_dp.py +1 -1
- wisent/{synthetic → core/synthetic}/generators/diversities/methods/fast_diversity.py +1 -1
- wisent/{synthetic → core/synthetic}/generators/pairs_generator.py +38 -12
- wisent/core/tasks/livecodebench_task.py +4 -103
- wisent/core/timing_calibration.py +1 -1
- {wisent-0.5.12.dist-info → wisent-0.5.13.dist-info}/METADATA +3 -3
- wisent-0.5.13.dist-info/RECORD +294 -0
- wisent-0.5.13.dist-info/entry_points.txt +2 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +0 -53
- wisent/classifiers/core/atoms.py +0 -747
- wisent/classifiers/models/logistic.py +0 -29
- wisent/classifiers/models/mlp.py +0 -47
- wisent/cli/classifiers/classifier_rotator.py +0 -137
- wisent/cli/cli_logger.py +0 -142
- wisent/cli/wisent_cli/commands/help_cmd.py +0 -52
- wisent/cli/wisent_cli/commands/listing.py +0 -154
- wisent/cli/wisent_cli/commands/train_cmd.py +0 -322
- wisent/cli/wisent_cli/main.py +0 -93
- wisent/cli/wisent_cli/shell.py +0 -80
- wisent/cli/wisent_cli/ui.py +0 -69
- wisent/cli/wisent_cli/util/aggregations.py +0 -43
- wisent/cli/wisent_cli/util/parsing.py +0 -126
- wisent/cli/wisent_cli/version.py +0 -4
- wisent/opti/methods/__init__.py +0 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent-0.5.12.dist-info/RECORD +0 -220
- /wisent/{benchmarks → core/evaluators/benchmark_specific/coding}/__init__.py +0 -0
- /wisent/{benchmarks/coding → core/evaluators/benchmark_specific/coding/metrics}/__init__.py +0 -0
- /wisent/{benchmarks/coding/metrics → core/evaluators/benchmark_specific/coding/metrics/core}/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/core/atoms.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/passk.py +0 -0
- /wisent/{benchmarks/coding/metrics/core → core/evaluators/benchmark_specific/coding/output_sanitizer}/__init__.py +0 -0
- /wisent/{benchmarks/coding/output_sanitizer → core/evaluators/benchmark_specific/coding/output_sanitizer/core}/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/core/atoms.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/utils.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/providers/__init__.py +0 -0
- /wisent/{benchmarks/coding/output_sanitizer → core/evaluators/benchmark_specific/coding/providers}/core/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/providers/core/atoms.py +0 -0
- /wisent/{benchmarks/coding/providers/core → core/evaluators/benchmark_specific/coding/safe_docker}/__init__.py +0 -0
- /wisent/{benchmarks/coding/providers/livecodebench → core/evaluators/benchmark_specific/coding/safe_docker/core}/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/core/atoms.py +0 -0
- /wisent/{benchmarks/coding/safe_docker → core/opti}/__init__.py +0 -0
- /wisent/{benchmarks/coding/safe_docker → core/opti}/core/__init__.py +0 -0
- /wisent/{opti → core/opti}/core/atoms.py +0 -0
- /wisent/{classifiers → core/opti/methods}/__init__.py +0 -0
- /wisent/{opti → core/opti}/methods/opti_classificator.py +0 -0
- /wisent/{opti → core/opti}/methods/opti_steering.py +0 -0
- /wisent/{classifiers/core → core/synthetic}/__init__.py +0 -0
- /wisent/{classifiers/models → core/synthetic/cleaners}/__init__.py +0 -0
- /wisent/{cli → core/synthetic/cleaners/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/cleaners/core/atoms.py +0 -0
- /wisent/{cli/classifiers → core/synthetic/cleaners/methods}/__init__.py +0 -0
- /wisent/{cli/data_loaders → core/synthetic/cleaners/methods/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/cleaners/methods/core/atoms.py +0 -0
- /wisent/{cli/evaluators → core/synthetic/db_instructions}/__init__.py +0 -0
- /wisent/{cli/steering_methods → core/synthetic/db_instructions/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/db_instructions/core/atoms.py +0 -0
- /wisent/{cli/wisent_cli → core/synthetic/generators}/__init__.py +0 -0
- /wisent/{cli/wisent_cli/commands → core/synthetic/generators/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/generators/core/atoms.py +0 -0
- /wisent/{cli/wisent_cli/util → core/synthetic/generators/diversities}/__init__.py +0 -0
- /wisent/{opti → core/synthetic/generators/diversities/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/generators/diversities/core/core.py +0 -0
- /wisent/{opti/core → core/synthetic/generators/diversities/methods}/__init__.py +0 -0
- {wisent-0.5.12.dist-info → wisent-0.5.13.dist-info}/WHEEL +0 -0
- {wisent-0.5.12.dist-info → wisent-0.5.13.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.5.12.dist-info → wisent-0.5.13.dist-info}/top_level.txt +0 -0
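
Every added lm_task_extractors module follows the same pattern, visible in the representative hunks below: load docs from an lm-eval task, pick one correct and one incorrect answer per doc, and wrap them in a ContrastivePair. As a rough sketch of the object these extractors emit (class and argument names are taken from the hunks themselves; the example values are invented):

from wisent.core.contrastive_pairs.core.pair import ContrastivePair
from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse

# One contrastive pair as the new extractors build it: a formatted prompt,
# the benchmark's correct answer as the positive response, a distractor as
# the negative response, and a short benchmark tag in `label`.
pair = ContrastivePair(
    prompt="The movie was a delight.\nQuestion: Is this sentence positive or negative?",
    positive_response=PositiveResponse(model_response="Positive"),
    negative_response=NegativeResponse(model_response="Negative"),
    label="sst2",
)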
wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py (new file)
@@ -0,0 +1,111 @@
+from __future__ import annotations
+
+from typing import Any, TYPE_CHECKING
+
+from wisent.core.contrastive_pairs.core.pair import ContrastivePair
+from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
+from wisent.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
+from wisent.core.cli_logger import setup_logger, bind
+
+if TYPE_CHECKING:
+    from lm_eval.api.task import ConfigurableTask
+
+
+__all__ = ["RTEExtractor"]
+_LOG = setup_logger(__name__)
+
+
+class RTEExtractor(LMEvalBenchmarkExtractor):
+    """Extractor for the RTE benchmark."""
+
+    def extract_contrastive_pairs(
+        self,
+        lm_eval_task_data: ConfigurableTask,
+        limit: int | None = None,
+    ) -> list[ContrastivePair]:
+        """
+        Build contrastive pairs from RTE docs.
+
+        RTE schema:
+            - sentence1: str
+            - sentence2: str
+            - label: 0 or 1
+
+        Args:
+            lm_eval_task_data: lm-eval task instance for RTE.
+            limit: Optional maximum number of pairs to produce.
+
+        Returns:
+            A list of ContrastivePair objects.
+        """
+        log = bind(_LOG, task=getattr(lm_eval_task_data, "NAME", "unknown"))
+
+        max_items = self._normalize_limit(limit)
+        docs = self.load_docs(lm_eval_task_data, max_items)
+
+        pairs: list[ContrastivePair] = []
+
+        log.info("Extracting contrastive pairs", extra={"doc_count": len(docs)})
+
+        for doc in docs:
+            pair = self._extract_pair_from_doc(doc)
+            if pair is not None:
+                pairs.append(pair)
+            if max_items is not None and len(pairs) >= max_items:
+                break
+
+        if not pairs:
+            task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
+            log.warning("No valid RTE pairs extracted", extra={"task": task_name})
+
+        return pairs
+
+    def _extract_pair_from_doc(self, doc: dict[str, Any]) -> ContrastivePair | None:
+        """
+        Convert a single RTE doc into a ContrastivePair, if possible.
+        Returns None when required fields are missing or malformed.
+        """
+        log = bind(_LOG, doc_id=doc.get("id", "unknown"))
+
+        try:
+            sentence1 = str(doc.get("sentence1", "")).strip()
+            sentence2 = str(doc.get("sentence2", "")).strip()
+            label = doc.get("label")
+
+            if not sentence1 or not sentence2 or label not in {0, 1}:
+                log.debug(
+                    "Skipping doc due to missing/invalid fields",
+                    extra={"doc": doc},
+                )
+                return None
+
+            formatted_question = f"{sentence1}\nQuestion: {sentence2} True or False?\nAnswer:\nA. True\nB. False"
+
+            correct = "True" if label == 0 else "False"
+            incorrect = "False" if label == 0 else "True"
+
+            metadata = {
+                "label": "rte",
+            }
+
+            return self._build_pair(
+                question=formatted_question,
+                correct=correct,
+                incorrect=incorrect,
+                metadata=metadata,
+            )
+
+        except Exception as exc:
+            log.error("Error extracting pair from doc", exc_info=exc, extra={"doc": doc})
+            return None
+
+    @staticmethod
+    def _build_pair(
+        question: str,
+        correct: str,
+        incorrect: str,
+        metadata: dict[str, Any] | None = None,
+    ) -> ContrastivePair:
+        positive_response = PositiveResponse(model_response=correct)
+        negative_response = NegativeResponse(model_response=incorrect)
+        return ContrastivePair(prompt=question, positive_response=positive_response, negative_response=negative_response, label=metadata.get("label"))
wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py (new file)
@@ -0,0 +1,110 @@
+from __future__ import annotations
+
+from typing import Any, TYPE_CHECKING
+
+from wisent.core.contrastive_pairs.core.pair import ContrastivePair
+from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
+from wisent.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
+from wisent.core.cli_logger import setup_logger, bind
+
+if TYPE_CHECKING:
+    from lm_eval.api.task import ConfigurableTask
+
+
+__all__ = ["SciQExtractor"]
+_LOG = setup_logger(__name__)
+
+
+class SciQExtractor(LMEvalBenchmarkExtractor):
+    """Extractor for the SciQ benchmark."""
+
+    def extract_contrastive_pairs(
+        self,
+        lm_eval_task_data: ConfigurableTask,
+        limit: int | None = None,
+    ) -> list[ContrastivePair]:
+        """
+        Build contrastive pairs from SciQ docs.
+
+        SciQ schema:
+            - support: str
+            - question: str
+            - distractor1, distractor2, distractor3: str
+            - correct_answer: str
+
+        Args:
+            lm_eval_task_data: lm-eval task instance for SciQ.
+            limit: Optional maximum number of pairs to produce.
+
+        Returns:
+            A list of ContrastivePair objects.
+        """
+        log = bind(_LOG, task=getattr(lm_eval_task_data, "NAME", "unknown"))
+
+        max_items = self._normalize_limit(limit)
+        docs = self.load_docs(lm_eval_task_data, max_items)
+
+        pairs: list[ContrastivePair] = []
+
+        log.info("Extracting contrastive pairs", extra={"doc_count": len(docs)})
+
+        for doc in docs:
+            pair = self._extract_pair_from_doc(doc)
+            if pair is not None:
+                pairs.append(pair)
+            if max_items is not None and len(pairs) >= max_items:
+                break
+
+        if not pairs:
+            task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
+            log.warning("No valid SciQ pairs extracted", extra={"task": task_name})
+
+        return pairs
+
+    def _extract_pair_from_doc(self, doc: dict[str, Any]) -> ContrastivePair | None:
+        """
+        Convert a single SciQ doc into a ContrastivePair, if possible.
+        Returns None when required fields are missing or malformed.
+        """
+        log = bind(_LOG, doc_id=doc.get("id", "unknown"))
+
+        try:
+            support = str(doc.get("support", "")).strip()
+            question = str(doc.get("question", "")).strip()
+            correct = str(doc.get("correct_answer", "")).strip()
+            incorrect = str(doc.get("distractor1", "")).strip()  # take any distractor
+
+            if not support or not question or not correct or not incorrect:
+                log.debug(
+                    "Skipping doc due to missing/invalid fields",
+                    extra={"doc": doc},
+                )
+                return None
+
+            formatted_question = f"{support}\nQuestion: {question}\nAnswer:\nA. {incorrect}\nB. {correct}"
+
+            metadata = {
+                "label": "sciq",
+            }
+
+            return self._build_pair(
+                question=formatted_question,
+                correct=correct,
+                incorrect=incorrect,
+                metadata=metadata,
+            )
+
+        except Exception as exc:
+            log.error("Error extracting pair from doc", exc_info=exc, extra={"doc": doc})
+            return None
+
+    @staticmethod
+    def _build_pair(
+        question: str,
+        correct: str,
+        incorrect: str,
+        metadata: dict[str, Any] | None = None,
+    ) -> ContrastivePair:
+        positive_response = PositiveResponse(model_response=correct)
+        negative_response = NegativeResponse(model_response=incorrect)
+        return ContrastivePair(prompt=question, positive_response=positive_response, negative_response=negative_response, label=metadata.get("label"))
wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py (new file)
@@ -0,0 +1,114 @@
+from __future__ import annotations
+
+from typing import Any, TYPE_CHECKING
+
+from wisent.core.contrastive_pairs.core.pair import ContrastivePair
+from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
+from wisent.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
+from wisent.core.cli_logger import setup_logger, bind
+
+if TYPE_CHECKING:
+    from lm_eval.api.task import ConfigurableTask
+
+
+__all__ = ["Social_IQAExtractor"]
+_LOG = setup_logger(__name__)
+
+
+class Social_IQAExtractor(LMEvalBenchmarkExtractor):
+    """Extractor for the Social_IQA benchmark."""
+
+    def extract_contrastive_pairs(
+        self,
+        lm_eval_task_data: ConfigurableTask,
+        limit: int | None = None,
+    ) -> list[ContrastivePair]:
+        """
+        Build contrastive pairs from Social_IQA docs.
+
+        Social_IQA schema:
+            - context: str
+            - question: str
+            - answerA, answerB, answerC: str
+            - label: "1" or "2" or "3"
+
+        Args:
+            lm_eval_task_data: lm-eval task instance for Social_IQA.
+            limit: Optional maximum number of pairs to produce.
+
+        Returns:
+            A list of ContrastivePair objects.
+        """
+        log = bind(_LOG, task=getattr(lm_eval_task_data, "NAME", "unknown"))
+
+        max_items = self._normalize_limit(limit)
+        docs = self.load_docs(lm_eval_task_data, max_items)
+
+        pairs: list[ContrastivePair] = []
+
+        log.info("Extracting contrastive pairs", extra={"doc_count": len(docs)})
+
+        for doc in docs:
+            pair = self._extract_pair_from_doc(doc)
+            if pair is not None:
+                pairs.append(pair)
+            if max_items is not None and len(pairs) >= max_items:
+                break
+
+        if not pairs:
+            task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
+            log.warning("No valid Social_IQA pairs extracted", extra={"task": task_name})
+
+        return pairs
+
+    def _extract_pair_from_doc(self, doc: dict[str, Any]) -> ContrastivePair | None:
+        """
+        Convert a single Social_IQA doc into a ContrastivePair, if possible.
+        Returns None when required fields are missing or malformed.
+        """
+        log = bind(_LOG, doc_id=doc.get("id", "unknown"))
+
+        try:
+            context = str(doc.get("context", "")).strip()
+            question = str(doc.get("question", "")).strip()
+            answers = [str(doc.get("answerA", "")).strip(), str(doc.get("answerB", "")).strip(), str(doc.get("answerC", "")).strip()]
+            label = str(doc.get("label", "")).strip()
+
+            if not context or not question or not all(answers) or label not in {"1", "2", "3"}:
+                log.debug(
+                    "Skipping doc due to missing/invalid fields",
+                    extra={"doc": doc},
+                )
+                return None
+
+            label = int(label) - 1
+            correct = answers[label]
+            incorrect = answers[(label + 1) % len(answers)]
+
+            formatted_question = f"Q: {context} {question}\nA:\nA. {incorrect}\nB. {correct}"
+
+            metadata = {
+                "label": "social_iqa",
+            }
+
+            return self._build_pair(
+                question=formatted_question,
+                correct=correct,
+                incorrect=incorrect,
+                metadata=metadata,
+            )
+
+        except Exception as exc:
+            log.error("Error extracting pair from doc", exc_info=exc, extra={"doc": doc})
+            return None
+
+    @staticmethod
+    def _build_pair(
+        question: str,
+        correct: str,
+        incorrect: str,
+        metadata: dict[str, Any] | None = None,
+    ) -> ContrastivePair:
+        positive_response = PositiveResponse(model_response=correct)
+        negative_response = NegativeResponse(model_response=incorrect)
+        return ContrastivePair(prompt=question, positive_response=positive_response, negative_response=negative_response, label=metadata.get("label"))
wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/squad2.py (new file)
@@ -0,0 +1,124 @@
+from __future__ import annotations
+
+import random
+from typing import Any, TYPE_CHECKING
+
+from wisent.core.contrastive_pairs.core.pair import ContrastivePair
+from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
+from wisent.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
+from wisent.core.cli_logger import setup_logger, bind
+
+if TYPE_CHECKING:
+    from lm_eval.api.task import ConfigurableTask
+
+
+__all__ = ["SQuAD2Extractor"]
+_LOG = setup_logger(__name__)
+
+
+class SQuAD2Extractor(LMEvalBenchmarkExtractor):
+    """Extractor for the SQuAD2 benchmark."""
+
+    def extract_contrastive_pairs(
+        self,
+        lm_eval_task_data: ConfigurableTask,
+        limit: int | None = None,
+    ) -> list[ContrastivePair]:
+        """
+        Build contrastive pairs from SQuAD2 docs.
+
+        SQuAD2 schema:
+            - context: str
+            - question: str
+            - answers: dict with a "text" list
+            - title: str
+
+        Args:
+            lm_eval_task_data: lm-eval task instance for SQuAD2.
+            limit: Optional maximum number of pairs to produce.
+
+        Returns:
+            A list of ContrastivePair objects.
+        """
+        log = bind(_LOG, task=getattr(lm_eval_task_data, "NAME", "unknown"))
+
+        max_items = self._normalize_limit(limit)
+        docs = self.load_docs(lm_eval_task_data, max_items)
+
+        pairs: list[ContrastivePair] = []
+
+        log.info("Extracting contrastive pairs", extra={"doc_count": len(docs)})
+
+        for doc in docs:
+            pair = self._extract_pair_from_doc(doc)
+            if pair is not None:
+                pairs.append(pair)
+            if max_items is not None and len(pairs) >= max_items:
+                break
+
+        if not pairs:
+            task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
+            log.warning("No valid SQuAD2 pairs extracted", extra={"task": task_name})
+
+        return pairs
+
+    def _extract_pair_from_doc(self, doc: dict[str, Any]) -> ContrastivePair | None:
+        """
+        Convert a single SQuAD2 doc into a ContrastivePair, if possible.
+        Returns None when required fields are missing or malformed.
+        """
+        log = bind(_LOG, doc_id=doc.get("id", "unknown"))
+
+        try:
+            context = str(doc.get("context", "")).strip()
+            question = str(doc.get("question", "")).strip()
+            answers = doc.get("answers", {})
+            answers = answers.get("text", [])
+
+            if not context or not question:
+                log.debug(
+                    "Skipping doc due to missing/invalid fields",
+                    extra={"doc": doc},
+                )
+                return None
+
+            correct = answers[0] if answers else "No answer"
+            if correct == "No answer":
+                incorrect = "The answer is clearly stated in the background"
+            else:
+                # Create plausible but incorrect answers
+                incorrect_answers = [
+                    "The information is not provided in the background.",
+                    "This cannot be determined from the background.",
+                    "The background does not contain this information.",
+                ]
+                incorrect = random.choice(incorrect_answers)
+
+            title = doc.get("title")
+            formatted_question = f"Title: {title}\n\nBackground: {context}\n\nQuestion: {question}\n\nAnswer:\nA. {incorrect}\nB. {correct}"
+
+            metadata = {
+                "label": "squad2",
+            }
+
+            return self._build_pair(
+                question=formatted_question,
+                correct=correct,
+                incorrect=incorrect,
+                metadata=metadata,
+            )
+
+        except Exception as exc:
+            log.error("Error extracting pair from doc", exc_info=exc, extra={"doc": doc})
+            return None
+
+    @staticmethod
+    def _build_pair(
+        question: str,
+        correct: str,
+        incorrect: str,
+        metadata: dict[str, Any] | None = None,
+    ) -> ContrastivePair:
+        positive_response = PositiveResponse(model_response=correct)
+        negative_response = NegativeResponse(model_response=incorrect)
+        return ContrastivePair(prompt=question, positive_response=positive_response, negative_response=negative_response, label=metadata.get("label"))
wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sst2.py (new file)
@@ -0,0 +1,111 @@
+from __future__ import annotations
+
+from typing import Any, TYPE_CHECKING
+
+from wisent.core.contrastive_pairs.core.pair import ContrastivePair
+from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
+from wisent.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
+from wisent.core.cli_logger import setup_logger, bind
+
+if TYPE_CHECKING:
+    from lm_eval.api.task import ConfigurableTask
+
+
+__all__ = ["SST2Extractor"]
+_LOG = setup_logger(__name__)
+
+
+class SST2Extractor(LMEvalBenchmarkExtractor):
+    """Extractor for the SST2 benchmark."""
+
+    def extract_contrastive_pairs(
+        self,
+        lm_eval_task_data: ConfigurableTask,
+        limit: int | None = None,
+        preferred_doc: str | None = None,
+    ) -> list[ContrastivePair]:
+        """
+        Build contrastive pairs from SST2 docs.
+
+        SST2 schema:
+            - sentence: str
+            - label: 0 or 1
+
+        Args:
+            lm_eval_task_data: lm-eval task instance for SST2.
+            limit: Optional maximum number of pairs to produce.
+            preferred_doc: Preferred document source ("validation", "test", "training", "fewshot")
+
+        Returns:
+            A list of ContrastivePair objects.
+        """
+        log = bind(_LOG, task=getattr(lm_eval_task_data, "NAME", "unknown"))
+
+        max_items = self._normalize_limit(limit)
+        docs = self.load_docs(lm_eval_task_data, max_items, preferred_doc=preferred_doc)
+
+        pairs: list[ContrastivePair] = []
+
+        log.info("Extracting contrastive pairs", extra={"doc_count": len(docs)})
+
+        for doc in docs:
+            pair = self._extract_pair_from_doc(doc)
+            if pair is not None:
+                pairs.append(pair)
+            if max_items is not None and len(pairs) >= max_items:
+                break
+
+        if not pairs:
+            task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
+            log.warning("No valid SST2 pairs extracted", extra={"task": task_name})
+
+        return pairs
+
+    def _extract_pair_from_doc(self, doc: dict[str, Any]) -> ContrastivePair | None:
+        """
+        Convert a single SST2 doc into a ContrastivePair, if possible.
+        Returns None when required fields are missing or malformed.
+        """
+        log = bind(_LOG, doc_id=doc.get("id", "unknown"))
+
+        try:
+            sentence1 = str(doc.get("sentence", "")).strip()
+            label = doc.get("label")
+
+            if not sentence1 or label not in {0, 1}:
+                log.debug(
+                    "Skipping doc due to missing/invalid fields",
+                    extra={"doc": doc},
+                )
+                return None
+
+            formatted_question = f"{sentence1}\nQuestion: Is this sentence positive or negative?"
+
+            correct = "Positive" if label == 1 else "Negative"
+            incorrect = "Negative" if label == 1 else "Positive"
+
+            metadata = {
+                "label": "sst2",
+            }
+
+            return self._build_pair(
+                question=formatted_question,
+                correct=correct,
+                incorrect=incorrect,
+                metadata=metadata,
+            )
+
+        except Exception as exc:
+            log.error("Error extracting pair from doc", exc_info=exc, extra={"doc": doc})
+            return None
+
+    @staticmethod
+    def _build_pair(
+        question: str,
+        correct: str,
+        incorrect: str,
+        metadata: dict[str, Any] | None = None,
+    ) -> ContrastivePair:
+        positive_response = PositiveResponse(model_response=correct)
+        negative_response = NegativeResponse(model_response=incorrect)
+        return ContrastivePair(prompt=question, positive_response=positive_response, negative_response=negative_response, label=metadata.get("label"))
wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py (new file)
@@ -0,0 +1,112 @@
+from __future__ import annotations
+
+from typing import Any, TYPE_CHECKING
+
+from wisent.core.contrastive_pairs.core.pair import ContrastivePair
+from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
+from wisent.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
+from wisent.core.cli_logger import setup_logger, bind
+
+if TYPE_CHECKING:
+    from lm_eval.api.task import ConfigurableTask
+
+
+__all__ = ["SWAGExtractor"]
+_LOG = setup_logger(__name__)
+
+
+class SWAGExtractor(LMEvalBenchmarkExtractor):
+    """Extractor for the SWAG benchmark."""
+
+    def extract_contrastive_pairs(
+        self,
+        lm_eval_task_data: ConfigurableTask,
+        limit: int | None = None,
+    ) -> list[ContrastivePair]:
+        """
+        Build contrastive pairs from SWAG docs.
+
+        SWAG schema:
+            - startphrase: str
+            - ending0, ending1, ending2, ending3: str
+            - label: 0 or 1 or 2 or 3
+
+        Args:
+            lm_eval_task_data: lm-eval task instance for SWAG.
+            limit: Optional maximum number of pairs to produce.
+
+        Returns:
+            A list of ContrastivePair objects.
+        """
+        log = bind(_LOG, task=getattr(lm_eval_task_data, "NAME", "unknown"))
+
+        max_items = self._normalize_limit(limit)
+        docs = self.load_docs(lm_eval_task_data, max_items)
+
+        pairs: list[ContrastivePair] = []
+
+        log.info("Extracting contrastive pairs", extra={"doc_count": len(docs)})
+
+        for doc in docs:
+            pair = self._extract_pair_from_doc(doc)
+            if pair is not None:
+                pairs.append(pair)
+            if max_items is not None and len(pairs) >= max_items:
+                break
+
+        if not pairs:
+            task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
+            log.warning("No valid SWAG pairs extracted", extra={"task": task_name})
+
+        return pairs
+
+    def _extract_pair_from_doc(self, doc: dict[str, Any]) -> ContrastivePair | None:
+        """
+        Convert a single SWAG doc into a ContrastivePair, if possible.
+        Returns None when required fields are missing or malformed.
+        """
+        log = bind(_LOG, doc_id=doc.get("id", "unknown"))
+
+        try:
+            startphrase = str(doc.get("startphrase", "")).strip()
+            endings = [doc.get("ending0", ""), doc.get("ending1", ""), doc.get("ending2", ""), doc.get("ending3", "")]
+            label = doc.get("label")
+
+            if not startphrase or not all(endings) or label not in {0, 1, 2, 3}:
+                log.debug(
+                    "Skipping doc due to missing/invalid fields",
+                    extra={"doc": doc},
+                )
+                return None
+
+            correct = endings[label]
+            incorrect = endings[(label + 1) % len(endings)]
+
+            question = startphrase
+            formatted_question = f"{question}\nA. {incorrect}\nB. {correct}"
+
+            metadata = {
+                "label": "swag",
+            }
+
+            return self._build_pair(
+                question=formatted_question,
+                correct=correct,
+                incorrect=incorrect,
+                metadata=metadata,
+            )
+
+        except Exception as exc:
+            log.error("Error extracting pair from doc", exc_info=exc, extra={"doc": doc})
+            return None
+
+    @staticmethod
+    def _build_pair(
+        question: str,
+        correct: str,
+        incorrect: str,
+        metadata: dict[str, Any] | None = None,
+    ) -> ContrastivePair:
+        positive_response = PositiveResponse(model_response=correct)
+        negative_response = NegativeResponse(model_response=incorrect)
+        return ContrastivePair(prompt=question, positive_response=positive_response, negative_response=negative_response, label=metadata.get("label"))
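
A minimal driver for one of these extractors might look like the sketch below; get_task_dict is lm-eval's task loader, but the exact task key and the no-argument constructor are assumptions rather than anything this diff confirms:

from lm_eval.tasks import get_task_dict

from wisent.core.contrastive_pairs.lm_eval_pairs.lm_task_extractors.rte import RTEExtractor

task = get_task_dict(["rte"])["rte"]  # assumed lm-eval task key
extractor = RTEExtractor()            # assumed no-arg constructor
pairs = extractor.extract_contrastive_pairs(task, limit=4)
for pair in pairs:
    print(pair.prompt)
    print("  correct:", pair.positive_response.model_response)
    print("  incorrect:", pair.negative_response.model_response)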