wisent 0.5.12__py3-none-any.whl → 0.5.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of wisent might be problematic. Click here for more details.
- wisent/__init__.py +1 -1
- wisent/core/activations/__init__.py +26 -0
- wisent/core/activations/activations.py +96 -0
- wisent/core/activations/activations_collector.py +71 -20
- wisent/core/activations/prompt_construction_strategy.py +47 -0
- wisent/core/agent/__init__.py +1 -18
- wisent/core/agent/budget.py +2 -2
- wisent/core/agent/device_benchmarks.py +1 -1
- wisent/core/agent/diagnose/__init__.py +1 -55
- wisent/core/agent/diagnose/classifier_marketplace.py +8 -8
- wisent/core/agent/diagnose/response_diagnostics.py +4 -4
- wisent/core/agent/diagnose/synthetic_classifier_option.py +1 -1
- wisent/core/agent/diagnose/tasks/task_manager.py +3 -3
- wisent/core/agent/diagnose.py +2 -1
- wisent/core/autonomous_agent.py +10 -2
- wisent/core/benchmark_extractors.py +293 -0
- wisent/core/bigcode_integration.py +20 -7
- wisent/core/branding.py +108 -0
- wisent/core/cli/__init__.py +15 -0
- wisent/core/cli/create_steering_vector.py +138 -0
- wisent/core/cli/evaluate_responses.py +715 -0
- wisent/core/cli/generate_pairs.py +128 -0
- wisent/core/cli/generate_pairs_from_task.py +119 -0
- wisent/core/cli/generate_responses.py +129 -0
- wisent/core/cli/generate_vector_from_synthetic.py +149 -0
- wisent/core/cli/generate_vector_from_task.py +147 -0
- wisent/core/cli/get_activations.py +191 -0
- wisent/core/cli/optimize_classification.py +339 -0
- wisent/core/cli/optimize_steering.py +364 -0
- wisent/core/cli/tasks.py +182 -0
- wisent/core/cli_logger.py +22 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +27 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +49 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +118 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +146 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +129 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gsm8k.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/livecodebench.py +367 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +121 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +121 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +110 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/squad2.py +124 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sst2.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +127 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_gen.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +117 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +117 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +127 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +1 -1
- wisent/core/data_loaders/__init__.py +235 -0
- wisent/core/data_loaders/loaders/lm_loader.py +2 -2
- wisent/core/data_loaders/loaders/task_interface_loader.py +300 -0
- wisent/{cli/data_loaders/data_loader_rotator.py → core/data_loaders/rotator.py} +1 -1
- wisent/core/download_full_benchmarks.py +79 -2
- wisent/core/evaluators/benchmark_specific/__init__.py +26 -0
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/evaluator.py +17 -17
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/cpp_sanitizer.py +2 -2
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/java_sanitizer.py +2 -2
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/python_sanitizer.py +2 -2
- wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/__init__.py +3 -0
- wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/provider.py +305 -0
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/core/runtime.py +36 -4
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/entrypoint.py +2 -4
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/recipes.py +1 -1
- wisent/core/evaluators/benchmark_specific/coding/solution_generator.py +258 -0
- wisent/core/evaluators/benchmark_specific/exact_match_evaluator.py +79 -0
- wisent/core/evaluators/benchmark_specific/f1_evaluator.py +101 -0
- wisent/core/evaluators/benchmark_specific/generation_evaluator.py +197 -0
- wisent/core/{log_likelihoods_evaluator.py → evaluators/benchmark_specific/log_likelihoods_evaluator.py} +10 -2
- wisent/core/evaluators/benchmark_specific/perplexity_evaluator.py +140 -0
- wisent/core/evaluators/benchmark_specific/personalization_evaluator.py +250 -0
- wisent/{cli/evaluators/evaluator_rotator.py → core/evaluators/rotator.py} +4 -4
- wisent/core/lm_eval_harness_ground_truth.py +3 -2
- wisent/core/main.py +57 -0
- wisent/core/model_persistence.py +2 -2
- wisent/core/models/wisent_model.py +6 -6
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
- wisent/core/optuna/steering/steering_optimization.py +1 -1
- wisent/core/parser_arguments/__init__.py +10 -0
- wisent/core/parser_arguments/agent_parser.py +110 -0
- wisent/core/parser_arguments/configure_model_parser.py +7 -0
- wisent/core/parser_arguments/create_steering_vector_parser.py +59 -0
- wisent/core/parser_arguments/evaluate_parser.py +40 -0
- wisent/core/parser_arguments/evaluate_responses_parser.py +10 -0
- wisent/core/parser_arguments/full_optimize_parser.py +115 -0
- wisent/core/parser_arguments/generate_pairs_from_task_parser.py +33 -0
- wisent/core/parser_arguments/generate_pairs_parser.py +29 -0
- wisent/core/parser_arguments/generate_responses_parser.py +15 -0
- wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +127 -0
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +127 -0
- wisent/core/parser_arguments/generate_vector_parser.py +90 -0
- wisent/core/parser_arguments/get_activations_parser.py +90 -0
- wisent/core/parser_arguments/main_parser.py +152 -0
- wisent/core/parser_arguments/model_config_parser.py +59 -0
- wisent/core/parser_arguments/monitor_parser.py +17 -0
- wisent/core/parser_arguments/multi_steer_parser.py +47 -0
- wisent/core/parser_arguments/optimize_classification_parser.py +67 -0
- wisent/core/parser_arguments/optimize_sample_size_parser.py +58 -0
- wisent/core/parser_arguments/optimize_steering_parser.py +147 -0
- wisent/core/parser_arguments/synthetic_parser.py +93 -0
- wisent/core/parser_arguments/tasks_parser.py +584 -0
- wisent/core/parser_arguments/test_nonsense_parser.py +26 -0
- wisent/core/parser_arguments/utils.py +111 -0
- wisent/core/prompts/core/prompt_formater.py +3 -3
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +2 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +2 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +2 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +2 -0
- wisent/{cli/steering_methods/steering_rotator.py → core/steering_methods/rotator.py} +4 -4
- wisent/core/steering_optimizer.py +45 -21
- wisent/{synthetic → core/synthetic}/cleaners/deduper_cleaner.py +3 -3
- wisent/{synthetic → core/synthetic}/cleaners/methods/base_dedupers.py +2 -2
- wisent/{synthetic → core/synthetic}/cleaners/methods/base_refusalers.py +1 -1
- wisent/{synthetic → core/synthetic}/cleaners/pairs_cleaner.py +5 -5
- wisent/{synthetic → core/synthetic}/cleaners/refusaler_cleaner.py +4 -4
- wisent/{synthetic → core/synthetic}/db_instructions/mini_dp.py +1 -1
- wisent/{synthetic → core/synthetic}/generators/diversities/methods/fast_diversity.py +1 -1
- wisent/{synthetic → core/synthetic}/generators/pairs_generator.py +38 -12
- wisent/core/tasks/livecodebench_task.py +4 -103
- wisent/core/timing_calibration.py +1 -1
- {wisent-0.5.12.dist-info → wisent-0.5.14.dist-info}/METADATA +3 -3
- wisent-0.5.14.dist-info/RECORD +294 -0
- wisent-0.5.14.dist-info/entry_points.txt +2 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +0 -53
- wisent/classifiers/core/atoms.py +0 -747
- wisent/classifiers/models/logistic.py +0 -29
- wisent/classifiers/models/mlp.py +0 -47
- wisent/cli/classifiers/classifier_rotator.py +0 -137
- wisent/cli/cli_logger.py +0 -142
- wisent/cli/wisent_cli/commands/help_cmd.py +0 -52
- wisent/cli/wisent_cli/commands/listing.py +0 -154
- wisent/cli/wisent_cli/commands/train_cmd.py +0 -322
- wisent/cli/wisent_cli/main.py +0 -93
- wisent/cli/wisent_cli/shell.py +0 -80
- wisent/cli/wisent_cli/ui.py +0 -69
- wisent/cli/wisent_cli/util/aggregations.py +0 -43
- wisent/cli/wisent_cli/util/parsing.py +0 -126
- wisent/cli/wisent_cli/version.py +0 -4
- wisent/opti/methods/__init__.py +0 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent-0.5.12.dist-info/RECORD +0 -220
- /wisent/{benchmarks → core/evaluators/benchmark_specific/coding}/__init__.py +0 -0
- /wisent/{benchmarks/coding → core/evaluators/benchmark_specific/coding/metrics}/__init__.py +0 -0
- /wisent/{benchmarks/coding/metrics → core/evaluators/benchmark_specific/coding/metrics/core}/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/core/atoms.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/passk.py +0 -0
- /wisent/{benchmarks/coding/metrics/core → core/evaluators/benchmark_specific/coding/output_sanitizer}/__init__.py +0 -0
- /wisent/{benchmarks/coding/output_sanitizer → core/evaluators/benchmark_specific/coding/output_sanitizer/core}/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/core/atoms.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/utils.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/providers/__init__.py +0 -0
- /wisent/{benchmarks/coding/output_sanitizer → core/evaluators/benchmark_specific/coding/providers}/core/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/providers/core/atoms.py +0 -0
- /wisent/{benchmarks/coding/providers/core → core/evaluators/benchmark_specific/coding/safe_docker}/__init__.py +0 -0
- /wisent/{benchmarks/coding/providers/livecodebench → core/evaluators/benchmark_specific/coding/safe_docker/core}/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/core/atoms.py +0 -0
- /wisent/{benchmarks/coding/safe_docker → core/opti}/__init__.py +0 -0
- /wisent/{benchmarks/coding/safe_docker → core/opti}/core/__init__.py +0 -0
- /wisent/{opti → core/opti}/core/atoms.py +0 -0
- /wisent/{classifiers → core/opti/methods}/__init__.py +0 -0
- /wisent/{opti → core/opti}/methods/opti_classificator.py +0 -0
- /wisent/{opti → core/opti}/methods/opti_steering.py +0 -0
- /wisent/{classifiers/core → core/synthetic}/__init__.py +0 -0
- /wisent/{classifiers/models → core/synthetic/cleaners}/__init__.py +0 -0
- /wisent/{cli → core/synthetic/cleaners/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/cleaners/core/atoms.py +0 -0
- /wisent/{cli/classifiers → core/synthetic/cleaners/methods}/__init__.py +0 -0
- /wisent/{cli/data_loaders → core/synthetic/cleaners/methods/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/cleaners/methods/core/atoms.py +0 -0
- /wisent/{cli/evaluators → core/synthetic/db_instructions}/__init__.py +0 -0
- /wisent/{cli/steering_methods → core/synthetic/db_instructions/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/db_instructions/core/atoms.py +0 -0
- /wisent/{cli/wisent_cli → core/synthetic/generators}/__init__.py +0 -0
- /wisent/{cli/wisent_cli/commands → core/synthetic/generators/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/generators/core/atoms.py +0 -0
- /wisent/{cli/wisent_cli/util → core/synthetic/generators/diversities}/__init__.py +0 -0
- /wisent/{opti → core/synthetic/generators/diversities/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/generators/diversities/core/core.py +0 -0
- /wisent/{opti/core → core/synthetic/generators/diversities/methods}/__init__.py +0 -0
- {wisent-0.5.12.dist-info → wisent-0.5.14.dist-info}/WHEEL +0 -0
- {wisent-0.5.12.dist-info → wisent-0.5.14.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.5.12.dist-info → wisent-0.5.14.dist-info}/top_level.txt +0 -0
|
from __future__ import annotations

from typing import Any, TYPE_CHECKING

from wisent.core.contrastive_pairs.core.pair import ContrastivePair
from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
from wisent.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
from wisent.core.cli_logger import setup_logger, bind

if TYPE_CHECKING:
    from lm_eval.api.task import ConfigurableTask


__all__ = ["WiCExtractor"]
_LOG = setup_logger(__name__)


class WiCExtractor(LMEvalBenchmarkExtractor):
    """Extractor for the WiC (Word-in-Context) benchmark."""

    def extract_contrastive_pairs(
        self,
        lm_eval_task_data: ConfigurableTask,
        limit: int | None = None,
    ) -> list[ContrastivePair]:
        """
        Build contrastive pairs from WiC docs.

        WiC schema:
            - sentence1: str
            - sentence2: str
            - word: str
            - label: 0 or 1

        Args:
            lm_eval_task_data: lm-eval task instance for WiC.
            limit: Optional maximum number of pairs to produce.

        Returns:
            A list of ContrastivePair objects.
        """
        log = bind(_LOG, task=getattr(lm_eval_task_data, "NAME", "unknown"))

        max_items = self._normalize_limit(limit)
        docs = self.load_docs(lm_eval_task_data, max_items)

        pairs: list[ContrastivePair] = []

        log.info("Extracting contrastive pairs", extra={"doc_count": len(docs)})

        for doc in docs:
            pair = self._extract_pair_from_doc(doc)
            if pair is not None:
                pairs.append(pair)
                if max_items is not None and len(pairs) >= max_items:
                    break

        if not pairs:
            task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
            log.warning("No valid WiC pairs extracted", extra={"task": task_name})

        return pairs

    def _extract_pair_from_doc(self, doc: dict[str, Any]) -> ContrastivePair | None:
        """
        Convert a single WiC doc into a ContrastivePair, if possible.

        Returns None when required fields are missing or malformed.
        """
        log = bind(_LOG, doc_id=doc.get("id", "unknown"))

        try:
            sentence1 = str(doc.get("sentence1", "")).strip()
            sentence2 = str(doc.get("sentence2", "")).strip()
            word = str(doc.get("word", "")).strip()
            label = doc.get("label")

            # Robustness fix: also reject an empty target word — the WiC prompt
            # is meaningless without the word under comparison.  (The original
            # validated only the two sentences and the label.)
            if not sentence1 or not sentence2 or not word or label not in {0, 1}:
                log.debug(
                    "Skipping doc due to missing/invalid fields",
                    extra={"doc": doc},
                )
                return None

            formatted_question = (
                f"Sentence 1: {sentence1}\n"
                f"Sentence 2: {sentence2}\n"
                f"Question: Is the word '{word}' used in the same way in the two sentences above?\n"
                "Answer:\nA. Yes\nB. No"
            )

            correct = "Yes" if label == 1 else "No"
            incorrect = "No" if label == 1 else "Yes"

            metadata = {
                "label": "wic",
            }

            return self._build_pair(
                question=formatted_question,
                correct=correct,
                incorrect=incorrect,
                metadata=metadata,
            )

        except Exception as exc:
            log.error("Error extracting pair from doc", exc_info=exc, extra={"doc": doc})
            return None

    @staticmethod
    def _build_pair(
        question: str,
        correct: str,
        incorrect: str,
        metadata: dict[str, Any] | None = None,
    ) -> ContrastivePair:
        """
        Assemble a ContrastivePair from a prompt and its two candidate answers.

        Args:
            question: Fully formatted prompt text.
            correct: Response text for the positive side.
            incorrect: Response text for the negative side.
            metadata: Optional metadata dict; its "label" entry (if any) becomes
                the pair's label.

        Returns:
            The constructed ContrastivePair.
        """
        # Bug fix: the original called metadata.get("label") unconditionally,
        # which raises AttributeError for the declared default metadata=None.
        pair_label = metadata.get("label") if metadata else None
        positive_response = PositiveResponse(model_response=correct)
        negative_response = NegativeResponse(model_response=incorrect)
        return ContrastivePair(
            prompt=question,
            positive_response=positive_response,
            negative_response=negative_response,
            label=pair_label,
        )
@@ -5,7 +5,7 @@ from typing import Any, TYPE_CHECKING
|
|
|
5
5
|
from wisent.core.contrastive_pairs.core.pair import ContrastivePair
|
|
6
6
|
from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
|
|
7
7
|
from wisent.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
|
|
8
|
-
from wisent.
|
|
8
|
+
from wisent.core.cli_logger import setup_logger, bind
|
|
9
9
|
|
|
10
10
|
if TYPE_CHECKING:
|
|
11
11
|
from lm_eval.api.task import ConfigurableTask
|
|
from __future__ import annotations

from typing import Any, TYPE_CHECKING

from wisent.core.contrastive_pairs.core.pair import ContrastivePair
from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
from wisent.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
from wisent.core.cli_logger import setup_logger, bind

if TYPE_CHECKING:
    from lm_eval.api.task import ConfigurableTask


__all__ = ["WNLIExtractor"]
_LOG = setup_logger(__name__)


class WNLIExtractor(LMEvalBenchmarkExtractor):
    """Extractor for the WNLI benchmark."""

    def extract_contrastive_pairs(
        self,
        lm_eval_task_data: ConfigurableTask,
        limit: int | None = None,
    ) -> list[ContrastivePair]:
        """
        Build contrastive pairs from WNLI docs.

        WNLI schema:
            - sentence1: str
            - sentence2: str
            - label: 0 or 1

        Args:
            lm_eval_task_data: lm-eval task instance for WNLI.
            limit: Optional maximum number of pairs to produce.

        Returns:
            A list of ContrastivePair objects.
        """
        # NOTE: the original docstring said "task instance for RTE" — a
        # copy-paste slip from the RTE extractor; this is the WNLI task.
        log = bind(_LOG, task=getattr(lm_eval_task_data, "NAME", "unknown"))

        max_items = self._normalize_limit(limit)
        docs = self.load_docs(lm_eval_task_data, max_items)

        pairs: list[ContrastivePair] = []

        log.info("Extracting contrastive pairs", extra={"doc_count": len(docs)})

        for doc in docs:
            pair = self._extract_pair_from_doc(doc)
            if pair is not None:
                pairs.append(pair)
                if max_items is not None and len(pairs) >= max_items:
                    break

        if not pairs:
            task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
            log.warning("No valid WNLI pairs extracted", extra={"task": task_name})

        return pairs

    def _extract_pair_from_doc(self, doc: dict[str, Any]) -> ContrastivePair | None:
        """
        Convert a single WNLI doc into a ContrastivePair, if possible.

        Returns None when required fields are missing or malformed.
        """
        log = bind(_LOG, doc_id=doc.get("id", "unknown"))

        try:
            sentence1 = str(doc.get("sentence1", "")).strip()
            sentence2 = str(doc.get("sentence2", "")).strip()
            label = doc.get("label")

            if not sentence1 or not sentence2 or label not in {0, 1}:
                log.debug(
                    "Skipping doc due to missing/invalid fields",
                    extra={"doc": doc},
                )
                return None

            formatted_question = f"{sentence1}\nQuestion: {sentence2} True or False?\nAnswer:\nA. True\nB. False"

            # WNLI label 1 means the hypothesis is entailed (True).
            correct = "True" if label == 1 else "False"
            incorrect = "False" if label == 1 else "True"

            metadata = {
                "label": "wnli",
            }

            return self._build_pair(
                question=formatted_question,
                correct=correct,
                incorrect=incorrect,
                metadata=metadata,
            )

        except Exception as exc:
            log.error("Error extracting pair from doc", exc_info=exc, extra={"doc": doc})
            return None

    @staticmethod
    def _build_pair(
        question: str,
        correct: str,
        incorrect: str,
        metadata: dict[str, Any] | None = None,
    ) -> ContrastivePair:
        """
        Assemble a ContrastivePair from a prompt and its two candidate answers.

        Args:
            question: Fully formatted prompt text.
            correct: Response text for the positive side.
            incorrect: Response text for the negative side.
            metadata: Optional metadata dict; its "label" entry (if any) becomes
                the pair's label.

        Returns:
            The constructed ContrastivePair.
        """
        # Bug fix: the original called metadata.get("label") unconditionally,
        # which raises AttributeError for the declared default metadata=None.
        pair_label = metadata.get("label") if metadata else None
        positive_response = PositiveResponse(model_response=correct)
        negative_response = NegativeResponse(model_response=incorrect)
        return ContrastivePair(
            prompt=question,
            positive_response=positive_response,
            negative_response=negative_response,
            label=pair_label,
        )
from __future__ import annotations

from typing import Any, TYPE_CHECKING

from wisent.core.contrastive_pairs.core.pair import ContrastivePair
from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
from wisent.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
from wisent.core.cli_logger import setup_logger, bind

if TYPE_CHECKING:
    from lm_eval.api.task import ConfigurableTask


__all__ = ["WSCExtractor"]
_LOG = setup_logger(__name__)


class WSCExtractor(LMEvalBenchmarkExtractor):
    """Extractor for the WSC (Winograd Schema Challenge) benchmark."""

    def extract_contrastive_pairs(
        self,
        lm_eval_task_data: ConfigurableTask,
        limit: int | None = None,
    ) -> list[ContrastivePair]:
        """
        Build contrastive pairs from WSC docs.

        WSC schema:
            - text: str
            - span1_text: str
            - span2_text: str
            - label: int

        Args:
            lm_eval_task_data: lm-eval task instance for WSC.
            limit: Optional maximum number of pairs to produce.

        Returns:
            A list of ContrastivePair objects.
        """
        log = bind(_LOG, task=getattr(lm_eval_task_data, "NAME", "unknown"))

        max_items = self._normalize_limit(limit)
        docs = self.load_docs(lm_eval_task_data, max_items)

        pairs: list[ContrastivePair] = []

        log.info("Extracting contrastive pairs", extra={"doc_count": len(docs)})

        for doc in docs:
            pair = self._extract_pair_from_doc(doc)
            if pair is not None:
                pairs.append(pair)
                if max_items is not None and len(pairs) >= max_items:
                    break

        if not pairs:
            task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
            log.warning("No valid WSC pairs extracted", extra={"task": task_name})

        return pairs

    def _extract_pair_from_doc(self, doc: dict[str, Any]) -> ContrastivePair | None:
        """
        Convert a single WSC doc into a ContrastivePair, if possible.

        Returns None when required fields are missing or malformed.
        """
        log = bind(_LOG, doc_id=doc.get("id", "unknown"))

        try:
            text = str(doc.get("text", "")).strip()
            span1_text = str(doc.get("span1_text", "")).strip()
            span2_text = str(doc.get("span2_text", "")).strip()
            label = doc.get("label")

            if not text or not span1_text or not span2_text or label not in {0, 1}:
                log.debug(
                    "Skipping doc due to missing/invalid fields",
                    extra={"doc": doc},
                )
                return None

            # Bug fix: the original prompt read "Question:In the passage..."
            # with no space after the colon.
            formatted_question = (
                f"Passage: {text}\n"
                f"Question: In the passage above, does the pronoun \"*{span2_text}*\" "
                f"refer to \"*{span1_text}*\"?\n"
                "A. Yes\nB. No"
            )

            correct = "Yes" if label == 1 else "No"
            incorrect = "No" if label == 1 else "Yes"

            metadata = {
                "label": "wsc",
            }

            return self._build_pair(
                question=formatted_question,
                correct=correct,
                incorrect=incorrect,
                metadata=metadata,
            )

        except Exception as exc:
            log.error("Error extracting pair from doc", exc_info=exc, extra={"doc": doc})
            return None

    @staticmethod
    def _build_pair(
        question: str,
        correct: str,
        incorrect: str,
        metadata: dict[str, Any] | None = None,
    ) -> ContrastivePair:
        """
        Assemble a ContrastivePair from a prompt and its two candidate answers.

        Args:
            question: Fully formatted prompt text.
            correct: Response text for the positive side.
            incorrect: Response text for the negative side.
            metadata: Optional metadata dict; its "label" entry (if any) becomes
                the pair's label.

        Returns:
            The constructed ContrastivePair.
        """
        # Bug fix: the original called metadata.get("label") unconditionally,
        # which raises AttributeError for the declared default metadata=None.
        pair_label = metadata.get("label") if metadata else None
        positive_response = PositiveResponse(model_response=correct)
        negative_response = NegativeResponse(model_response=incorrect)
        return ContrastivePair(
            prompt=question,
            positive_response=positive_response,
            negative_response=negative_response,
            label=pair_label,
        )
@@ -0,0 +1,112 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
from wisent.core.contrastive_pairs.core.pair import ContrastivePair
|
|
6
|
+
from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
|
|
7
|
+
from wisent.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
|
|
8
|
+
from wisent.core.cli_logger import setup_logger, bind
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from lm_eval.api.task import ConfigurableTask
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
__all__ = ["XNLIExtractor"]
|
|
15
|
+
_LOG = setup_logger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class XNLIExtractor(LMEvalBenchmarkExtractor):
    """Extractor for the XNLI benchmark."""

    def extract_contrastive_pairs(
        self,
        lm_eval_task_data: ConfigurableTask,
        limit: int | None = None,
    ) -> list[ContrastivePair]:
        """
        Build contrastive pairs from XNLI docs.

        XNLI schema:
            - premise: str
            - hypothesis: str
            - label: 0 or 1 or 2

        Args:
            lm_eval_task_data: lm-eval task instance for XNLI.
            limit: Optional maximum number of pairs to produce.

        Returns:
            A list of ContrastivePair objects.
        """
        log = bind(_LOG, task=getattr(lm_eval_task_data, "NAME", "unknown"))

        max_items = self._normalize_limit(limit)
        docs = self.load_docs(lm_eval_task_data, max_items)

        pairs: list[ContrastivePair] = []

        log.info("Extracting contrastive pairs", extra={"doc_count": len(docs)})

        for doc in docs:
            pair = self._extract_pair_from_doc(doc)
            if pair is not None:
                pairs.append(pair)
                if max_items is not None and len(pairs) >= max_items:
                    break

        if not pairs:
            task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
            log.warning("No valid XNLI pairs extracted", extra={"task": task_name})

        return pairs

    def _extract_pair_from_doc(self, doc: dict[str, Any]) -> ContrastivePair | None:
        """
        Convert a single XNLI doc into a ContrastivePair, if possible.
        Returns None when required fields are missing or malformed.
        """
        log = bind(_LOG, doc_id=doc.get("id", "unknown"))

        try:
            premise = str(doc.get("premise", "")).strip()
            hypothesis = str(doc.get("hypothesis", "")).strip()
            label = doc.get("label")

            if not premise or not hypothesis or label not in {0, 1, 2}:
                log.debug(
                    "Skipping doc due to missing/invalid fields",
                    extra={"doc": doc},
                )
                return None

            labels = {0: "entailment", 1: "neutral", 2: "contradiction"}
            correct = labels[label]
            # Deterministically pick one of the two wrong labels as the negative.
            incorrect = labels[(label + 1) % 3]

            # BUG FIX: the closing quote after the premise was missing, leaving
            # an unbalanced quotation mark in the prompt text.
            formatted_question = (
                f"Decide the relationship of the hypothesis '{hypothesis}' "
                f"to the premise '{premise}'\nA. {incorrect}\nB. {correct}"
            )

            metadata = {
                "label": "xnli",
            }

            return self._build_pair(
                question=formatted_question,
                correct=correct,
                incorrect=incorrect,
                metadata=metadata,
            )

        except Exception as exc:
            log.error("Error extracting pair from doc", exc_info=exc, extra={"doc": doc})
            return None

    @staticmethod
    def _build_pair(
        question: str,
        correct: str,
        incorrect: str,
        metadata: dict[str, Any] | None = None,
    ) -> ContrastivePair:
        """Assemble a ContrastivePair from a prompt and its two responses."""
        positive_response = PositiveResponse(model_response=correct)
        negative_response = NegativeResponse(model_response=incorrect)
        # BUG FIX: metadata defaults to None, so calling metadata.get() directly
        # would raise AttributeError; fall back to an empty mapping first.
        return ContrastivePair(
            prompt=question,
            positive_response=positive_response,
            negative_response=negative_response,
            label=(metadata or {}).get("label"),
        )
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import random
|
|
4
|
+
from typing import Any, TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
from wisent.core.contrastive_pairs.core.pair import ContrastivePair
|
|
7
|
+
from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
|
|
8
|
+
from wisent.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
|
|
9
|
+
from wisent.core.cli_logger import setup_logger, bind
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from lm_eval.api.task import ConfigurableTask
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# BUG FIX: __all__ must name the class actually defined in this module
# (XStoryClozeExtractor); "XStoryCloze" does not exist, so
# `from <module> import *` would raise AttributeError.
__all__ = ["XStoryClozeExtractor"]
|
|
16
|
+
_LOG = setup_logger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class XStoryClozeExtractor(LMEvalBenchmarkExtractor):
    """Extractor for the XStoryCloze benchmark."""

    def extract_contrastive_pairs(
        self,
        lm_eval_task_data: ConfigurableTask,
        limit: int | None = None,
    ) -> list[ContrastivePair]:
        """
        Build contrastive pairs from XStoryCloze docs.

        XStoryCloze schema:
            - input_sentence_1, input_sentence_2, input_sentence_3, input_sentence_4: str
            - sentence_quiz1, sentence_quiz2: str
            - answer_right_ending: 1 or 2 (1-based index of the correct quiz ending)

        Args:
            lm_eval_task_data: lm-eval task instance for XStoryCloze.
            limit: Optional maximum number of pairs to produce.

        Returns:
            A list of ContrastivePair objects.
        """
        log = bind(_LOG, task=getattr(lm_eval_task_data, "NAME", "unknown"))

        max_items = self._normalize_limit(limit)
        docs = self.load_docs(lm_eval_task_data, max_items)

        pairs: list[ContrastivePair] = []

        log.info("Extracting contrastive pairs", extra={"doc_count": len(docs)})

        for doc in docs:
            pair = self._extract_pair_from_doc(doc)
            if pair is not None:
                pairs.append(pair)
                if max_items is not None and len(pairs) >= max_items:
                    break

        if not pairs:
            task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
            log.warning("No valid XStoryCloze pairs extracted", extra={"task": task_name})

        return pairs

    def _extract_pair_from_doc(self, doc: dict[str, Any]) -> ContrastivePair | None:
        """
        Convert a single XStoryCloze doc into a ContrastivePair, if possible.
        Returns None when required fields are missing or malformed.
        """
        log = bind(_LOG, doc_id=doc.get("id", "unknown"))

        try:
            inputs = [
                str(doc.get("input_sentence_1", "")).strip(),
                str(doc.get("input_sentence_2", "")).strip(),
                str(doc.get("input_sentence_3", "")).strip(),
                str(doc.get("input_sentence_4", "")).strip(),
            ]
            # BUG FIX: use a default of "" so a missing field does not become
            # the literal string "None" via str(None).
            endings = [
                str(doc.get("sentence_quiz1", "")).strip(),
                str(doc.get("sentence_quiz2", "")).strip(),
            ]
            right_ending = doc.get("answer_right_ending")

            # BUG FIX: the original checked `not answer` on the 0-based index,
            # which rejected every doc whose correct ending was option 1
            # (index 0). Validate the raw 1-based field and the text fields
            # instead; `not inputs`/`not endings` were always False because
            # both are non-empty lists.
            if right_ending not in {1, 2} or not all(inputs) or not all(endings):
                log.debug(
                    "Skipping doc due to missing/invalid fields",
                    extra={"doc": doc},
                )
                return None

            answer = right_ending - 1  # convert to 0-based index into endings

            correct = endings[answer]
            # The other ending serves as the negative response.
            incorrect = endings[(answer + 1) % len(endings)]

            formatted_question = " ".join(s.strip() for s in inputs if s)
            formatted_question = f"{formatted_question}\n \nA. {incorrect}\nB. {correct}"

            metadata = {
                "label": "xstorycloze",
            }

            return self._build_pair(
                question=formatted_question,
                correct=correct,
                incorrect=incorrect,
                metadata=metadata,
            )

        except Exception as exc:
            log.error("Error extracting pair from doc", exc_info=exc, extra={"doc": doc})
            return None

    @staticmethod
    def _build_pair(
        question: str,
        correct: str,
        incorrect: str,
        metadata: dict[str, Any] | None = None,
    ) -> ContrastivePair:
        """Assemble a ContrastivePair from a prompt and its two responses."""
        positive_response = PositiveResponse(model_response=correct)
        negative_response = NegativeResponse(model_response=incorrect)
        # BUG FIX: metadata defaults to None, so calling metadata.get() directly
        # would raise AttributeError; fall back to an empty mapping first.
        return ContrastivePair(
            prompt=question,
            positive_response=positive_response,
            negative_response=negative_response,
            label=(metadata or {}).get("label"),
        )
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import random
|
|
4
|
+
from typing import Any, TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
from wisent.core.contrastive_pairs.core.pair import ContrastivePair
|
|
7
|
+
from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
|
|
8
|
+
from wisent.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
|
|
9
|
+
from wisent.core.cli_logger import setup_logger, bind
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from lm_eval.api.task import ConfigurableTask
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# BUG FIX: __all__ must name the class actually defined in this module
# (XWinogradExtractor); "XWinograd" does not exist, so
# `from <module> import *` would raise AttributeError.
__all__ = ["XWinogradExtractor"]
|
|
16
|
+
_LOG = setup_logger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class XWinogradExtractor(LMEvalBenchmarkExtractor):
    """Extractor for the XWinograd benchmark."""

    def extract_contrastive_pairs(
        self,
        lm_eval_task_data: ConfigurableTask,
        limit: int | None = None,
    ) -> list[ContrastivePair]:
        """
        Build contrastive pairs from XWinograd docs.

        XWinograd schema:
            - sentence: str
            - option1: str
            - option2: str
            - answer: str ("1" or "2")

        Args:
            lm_eval_task_data: lm-eval task instance for XWinograd.
            limit: Optional maximum number of pairs to produce.

        Returns:
            A list of ContrastivePair objects.
        """
        log = bind(_LOG, task=getattr(lm_eval_task_data, "NAME", "unknown"))

        max_items = self._normalize_limit(limit)
        docs = self.load_docs(lm_eval_task_data, max_items)

        pairs: list[ContrastivePair] = []

        log.info("Extracting contrastive pairs", extra={"doc_count": len(docs)})

        for doc in docs:
            pair = self._extract_pair_from_doc(doc)
            if pair is not None:
                pairs.append(pair)
                if max_items is not None and len(pairs) >= max_items:
                    break

        if not pairs:
            task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
            log.warning("No valid XWinograd pairs extracted", extra={"task": task_name})

        return pairs

    def _extract_pair_from_doc(self, doc: dict[str, Any]) -> ContrastivePair | None:
        """
        Convert a single XWinograd doc into a ContrastivePair, if possible.
        Returns None when required fields are missing or malformed.
        """
        log = bind(_LOG, doc_id=doc.get("id", "unknown"))

        try:
            sentence = str(doc.get("sentence", "")).strip()
            options = [str(doc.get("option1", "")).strip(), str(doc.get("option2", "")).strip()]
            answer_raw = str(doc.get("answer", "")).strip()

            # BUG FIX: validate the raw answer string before int(); the original
            # called int(answer) first, so an empty/non-numeric answer raised
            # ValueError and was logged as an error instead of a skip. Also
            # check the option texts — `not options` was always False since
            # options is a non-empty list.
            if not sentence or not all(options) or answer_raw not in {"1", "2"}:
                log.debug(
                    "Skipping doc due to missing/invalid fields",
                    extra={"doc": doc},
                )
                return None

            answer = int(answer_raw) - 1  # convert to 0-based index into options

            correct = options[answer]
            # The other option serves as the negative response.
            incorrect = options[(answer + 1) % len(options)]

            formatted_question = f"Fill in the blank: {sentence}\nA. {incorrect}\nB. {correct}"

            metadata = {
                "label": "xwinograd",
            }

            return self._build_pair(
                question=formatted_question,
                correct=correct,
                incorrect=incorrect,
                metadata=metadata,
            )

        except Exception as exc:
            log.error("Error extracting pair from doc", exc_info=exc, extra={"doc": doc})
            return None

    @staticmethod
    def _build_pair(
        question: str,
        correct: str,
        incorrect: str,
        metadata: dict[str, Any] | None = None,
    ) -> ContrastivePair:
        """Assemble a ContrastivePair from a prompt and its two responses."""
        positive_response = PositiveResponse(model_response=correct)
        negative_response = NegativeResponse(model_response=incorrect)
        # BUG FIX: metadata defaults to None, so calling metadata.get() directly
        # would raise AttributeError; fall back to an empty mapping first.
        return ContrastivePair(
            prompt=question,
            positive_response=positive_response,
            negative_response=negative_response,
            label=(metadata or {}).get("label"),
        )
|