wisent 0.5.11__py3-none-any.whl → 0.5.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of wisent might be problematic.
- wisent/__init__.py +1 -1
- wisent/core/activations/__init__.py +26 -0
- wisent/core/activations/activations.py +96 -0
- wisent/core/activations/activations_collector.py +71 -20
- wisent/core/activations/prompt_construction_strategy.py +47 -0
- wisent/core/agent/budget.py +2 -2
- wisent/core/agent/device_benchmarks.py +1 -1
- wisent/core/agent/diagnose/classifier_marketplace.py +8 -8
- wisent/core/agent/diagnose/response_diagnostics.py +4 -4
- wisent/core/agent/diagnose/synthetic_classifier_option.py +1 -1
- wisent/core/agent/diagnose/tasks/task_manager.py +3 -3
- wisent/core/agent/diagnose.py +2 -1
- wisent/core/autonomous_agent.py +10 -2
- wisent/core/benchmark_extractors.py +293 -0
- wisent/core/bigcode_integration.py +20 -7
- wisent/core/branding.py +108 -0
- wisent/core/cli/__init__.py +15 -0
- wisent/core/cli/create_steering_vector.py +138 -0
- wisent/core/cli/evaluate_responses.py +715 -0
- wisent/core/cli/generate_pairs.py +128 -0
- wisent/core/cli/generate_pairs_from_task.py +119 -0
- wisent/core/cli/generate_responses.py +129 -0
- wisent/core/cli/generate_vector_from_synthetic.py +149 -0
- wisent/core/cli/generate_vector_from_task.py +147 -0
- wisent/core/cli/get_activations.py +191 -0
- wisent/core/cli/optimize_classification.py +339 -0
- wisent/core/cli/optimize_steering.py +364 -0
- wisent/core/cli/tasks.py +182 -0
- wisent/core/cli_logger.py +22 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +27 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +49 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +118 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +146 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +129 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gsm8k.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/livecodebench.py +367 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +121 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +121 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +110 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/squad2.py +124 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sst2.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +127 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_gen.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +117 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +117 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +127 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +1 -1
- wisent/core/data_loaders/__init__.py +235 -0
- wisent/core/data_loaders/loaders/lm_loader.py +2 -2
- wisent/core/data_loaders/loaders/task_interface_loader.py +300 -0
- wisent/{cli/data_loaders/data_loader_rotator.py → core/data_loaders/rotator.py} +1 -1
- wisent/core/download_full_benchmarks.py +79 -2
- wisent/core/evaluators/benchmark_specific/__init__.py +26 -0
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/evaluator.py +17 -17
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/cpp_sanitizer.py +2 -2
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/java_sanitizer.py +2 -2
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/python_sanitizer.py +2 -2
- wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/__init__.py +3 -0
- wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/provider.py +305 -0
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/core/runtime.py +36 -4
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/entrypoint.py +2 -4
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/recipes.py +1 -1
- wisent/core/evaluators/benchmark_specific/coding/solution_generator.py +258 -0
- wisent/core/evaluators/benchmark_specific/exact_match_evaluator.py +79 -0
- wisent/core/evaluators/benchmark_specific/f1_evaluator.py +101 -0
- wisent/core/evaluators/benchmark_specific/generation_evaluator.py +197 -0
- wisent/core/{log_likelihoods_evaluator.py → evaluators/benchmark_specific/log_likelihoods_evaluator.py} +10 -2
- wisent/core/evaluators/benchmark_specific/perplexity_evaluator.py +140 -0
- wisent/core/evaluators/benchmark_specific/personalization_evaluator.py +250 -0
- wisent/{cli/evaluators/evaluator_rotator.py → core/evaluators/rotator.py} +4 -4
- wisent/core/lm_eval_harness_ground_truth.py +3 -2
- wisent/core/main.py +57 -0
- wisent/core/model_persistence.py +2 -2
- wisent/core/models/wisent_model.py +8 -6
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
- wisent/core/optuna/steering/steering_optimization.py +1 -1
- wisent/core/parser_arguments/__init__.py +10 -0
- wisent/core/parser_arguments/agent_parser.py +110 -0
- wisent/core/parser_arguments/configure_model_parser.py +7 -0
- wisent/core/parser_arguments/create_steering_vector_parser.py +59 -0
- wisent/core/parser_arguments/evaluate_parser.py +40 -0
- wisent/core/parser_arguments/evaluate_responses_parser.py +10 -0
- wisent/core/parser_arguments/full_optimize_parser.py +115 -0
- wisent/core/parser_arguments/generate_pairs_from_task_parser.py +33 -0
- wisent/core/parser_arguments/generate_pairs_parser.py +29 -0
- wisent/core/parser_arguments/generate_responses_parser.py +15 -0
- wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +127 -0
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +127 -0
- wisent/core/parser_arguments/generate_vector_parser.py +90 -0
- wisent/core/parser_arguments/get_activations_parser.py +90 -0
- wisent/core/parser_arguments/main_parser.py +152 -0
- wisent/core/parser_arguments/model_config_parser.py +59 -0
- wisent/core/parser_arguments/monitor_parser.py +17 -0
- wisent/core/parser_arguments/multi_steer_parser.py +47 -0
- wisent/core/parser_arguments/optimize_classification_parser.py +67 -0
- wisent/core/parser_arguments/optimize_sample_size_parser.py +58 -0
- wisent/core/parser_arguments/optimize_steering_parser.py +147 -0
- wisent/core/parser_arguments/synthetic_parser.py +93 -0
- wisent/core/parser_arguments/tasks_parser.py +584 -0
- wisent/core/parser_arguments/test_nonsense_parser.py +26 -0
- wisent/core/parser_arguments/utils.py +111 -0
- wisent/core/prompts/core/prompt_formater.py +3 -3
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +2 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +2 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +2 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +2 -0
- wisent/{cli/steering_methods/steering_rotator.py → core/steering_methods/rotator.py} +4 -4
- wisent/core/steering_optimizer.py +45 -21
- wisent/{synthetic → core/synthetic}/cleaners/deduper_cleaner.py +3 -3
- wisent/{synthetic → core/synthetic}/cleaners/methods/base_dedupers.py +2 -2
- wisent/{synthetic → core/synthetic}/cleaners/methods/base_refusalers.py +1 -1
- wisent/{synthetic → core/synthetic}/cleaners/pairs_cleaner.py +5 -5
- wisent/{synthetic → core/synthetic}/cleaners/refusaler_cleaner.py +4 -4
- wisent/{synthetic → core/synthetic}/db_instructions/mini_dp.py +1 -1
- wisent/{synthetic → core/synthetic}/generators/diversities/methods/fast_diversity.py +1 -1
- wisent/{synthetic → core/synthetic}/generators/pairs_generator.py +38 -12
- wisent/core/tasks/livecodebench_task.py +4 -103
- wisent/core/timing_calibration.py +1 -1
- {wisent-0.5.11.dist-info → wisent-0.5.13.dist-info}/METADATA +3 -3
- wisent-0.5.13.dist-info/RECORD +294 -0
- wisent-0.5.13.dist-info/entry_points.txt +2 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +0 -53
- wisent/classifiers/core/atoms.py +0 -747
- wisent/classifiers/models/logistic.py +0 -29
- wisent/classifiers/models/mlp.py +0 -47
- wisent/cli/classifiers/classifier_rotator.py +0 -137
- wisent/cli/cli_logger.py +0 -142
- wisent/cli/wisent_cli/commands/help_cmd.py +0 -52
- wisent/cli/wisent_cli/commands/listing.py +0 -154
- wisent/cli/wisent_cli/commands/train_cmd.py +0 -322
- wisent/cli/wisent_cli/main.py +0 -93
- wisent/cli/wisent_cli/shell.py +0 -80
- wisent/cli/wisent_cli/ui.py +0 -69
- wisent/cli/wisent_cli/util/aggregations.py +0 -43
- wisent/cli/wisent_cli/util/parsing.py +0 -126
- wisent/cli/wisent_cli/version.py +0 -4
- wisent/opti/methods/__init__.py +0 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent-0.5.11.dist-info/RECORD +0 -220
- /wisent/{benchmarks → core/evaluators/benchmark_specific/coding}/__init__.py +0 -0
- /wisent/{benchmarks/coding → core/evaluators/benchmark_specific/coding/metrics}/__init__.py +0 -0
- /wisent/{benchmarks/coding/metrics → core/evaluators/benchmark_specific/coding/metrics/core}/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/core/atoms.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/passk.py +0 -0
- /wisent/{benchmarks/coding/metrics/core → core/evaluators/benchmark_specific/coding/output_sanitizer}/__init__.py +0 -0
- /wisent/{benchmarks/coding/output_sanitizer → core/evaluators/benchmark_specific/coding/output_sanitizer/core}/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/core/atoms.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/utils.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/providers/__init__.py +0 -0
- /wisent/{benchmarks/coding/output_sanitizer → core/evaluators/benchmark_specific/coding/providers}/core/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/providers/core/atoms.py +0 -0
- /wisent/{benchmarks/coding/providers/core → core/evaluators/benchmark_specific/coding/safe_docker}/__init__.py +0 -0
- /wisent/{benchmarks/coding/providers/livecodebench → core/evaluators/benchmark_specific/coding/safe_docker/core}/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/core/atoms.py +0 -0
- /wisent/{benchmarks/coding/safe_docker → core/opti}/__init__.py +0 -0
- /wisent/{benchmarks/coding/safe_docker → core/opti}/core/__init__.py +0 -0
- /wisent/{opti → core/opti}/core/atoms.py +0 -0
- /wisent/{classifiers → core/opti/methods}/__init__.py +0 -0
- /wisent/{opti → core/opti}/methods/opti_classificator.py +0 -0
- /wisent/{opti → core/opti}/methods/opti_steering.py +0 -0
- /wisent/{classifiers/core → core/synthetic}/__init__.py +0 -0
- /wisent/{classifiers/models → core/synthetic/cleaners}/__init__.py +0 -0
- /wisent/{cli → core/synthetic/cleaners/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/cleaners/core/atoms.py +0 -0
- /wisent/{cli/classifiers → core/synthetic/cleaners/methods}/__init__.py +0 -0
- /wisent/{cli/data_loaders → core/synthetic/cleaners/methods/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/cleaners/methods/core/atoms.py +0 -0
- /wisent/{cli/evaluators → core/synthetic/db_instructions}/__init__.py +0 -0
- /wisent/{cli/steering_methods → core/synthetic/db_instructions/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/db_instructions/core/atoms.py +0 -0
- /wisent/{cli/wisent_cli → core/synthetic/generators}/__init__.py +0 -0
- /wisent/{cli/wisent_cli/commands → core/synthetic/generators/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/generators/core/atoms.py +0 -0
- /wisent/{cli/wisent_cli/util → core/synthetic/generators/diversities}/__init__.py +0 -0
- /wisent/{opti → core/synthetic/generators/diversities/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/generators/diversities/core/core.py +0 -0
- /wisent/{opti/core → core/synthetic/generators/diversities/methods}/__init__.py +0 -0
- {wisent-0.5.11.dist-info → wisent-0.5.13.dist-info}/WHEEL +0 -0
- {wisent-0.5.11.dist-info → wisent-0.5.13.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.5.11.dist-info → wisent-0.5.13.dist-info}/top_level.txt +0 -0
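
Beyond the new extractors, the listing shows a package-wide restructure: wisent/benchmarks and wisent/synthetic now live under wisent/core, the old wisent/cli tree is removed, and wisent/core/cli takes its place. A minimal compatibility sketch for code that imported from the old layout, assuming a hypothetical PairsGenerator symbol (the diff shows the module rename but not the names defined inside pairs_generator.py):

# Hypothetical shim for the wisent/synthetic -> wisent/core/synthetic rename;
# PairsGenerator is an assumed name, shown only to illustrate the path change.
try:
    from wisent.core.synthetic.generators.pairs_generator import PairsGenerator  # 0.5.13 layout
except ImportError:
    from wisent.synthetic.generators.pairs_generator import PairsGenerator  # 0.5.11 layout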
wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py
@@ -0,0 +1,114 @@
+from __future__ import annotations
+
+from typing import Any, TYPE_CHECKING
+
+from wisent.core.contrastive_pairs.core.pair import ContrastivePair
+from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
+from wisent.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
+from wisent.core.cli_logger import setup_logger, bind
+
+if TYPE_CHECKING:
+    from lm_eval.api.task import ConfigurableTask
+
+
+__all__ = ["CBExtractor"]
+_LOG = setup_logger(__name__)
+
+
+class CBExtractor(LMEvalBenchmarkExtractor):
+    """Extractor for the CB benchmark."""
+
+    def extract_contrastive_pairs(
+        self,
+        lm_eval_task_data: ConfigurableTask,
+        limit: int | None = None,
+        preferred_doc: str | None = None,
+    ) -> list[ContrastivePair]:
+        """
+        Build contrastive pairs from CB docs.
+
+        CB schema:
+            - premise: str
+            - hypothesis: str
+            - label: int, 0 for "True", 1 for "False", 2 for "Neither"
+
+        Args:
+            lm_eval_task_data: lm-eval task instance for CB.
+            limit: Optional maximum number of pairs to produce.
+            preferred_doc: Preferred document source ("validation", "test", "training", "fewshot").
+
+        Returns:
+            A list of ContrastivePair objects.
+        """
+        log = bind(_LOG, task=getattr(lm_eval_task_data, "NAME", "unknown"))
+
+        max_items = self._normalize_limit(limit)
+        docs = self.load_docs(lm_eval_task_data, max_items, preferred_doc=preferred_doc)
+
+        pairs: list[ContrastivePair] = []
+
+        log.info("Extracting contrastive pairs", extra={"doc_count": len(docs)})
+
+        for doc in docs:
+            pair = self._extract_pair_from_doc(doc)
+            if pair is not None:
+                pairs.append(pair)
+            if max_items is not None and len(pairs) >= max_items:
+                break
+
+        if not pairs:
+            task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
+            log.warning("No valid CB pairs extracted", extra={"task": task_name})
+
+        return pairs
+
+    def _extract_pair_from_doc(self, doc: dict[str, Any]) -> ContrastivePair | None:
+        """
+        Convert a single CB doc into a ContrastivePair, if possible.
+        Returns None when required fields are missing or malformed.
+        """
+        log = bind(_LOG, doc_id=doc.get("id", "unknown"))
+
+        try:
+            premise = str(doc.get("premise", "")).strip()
+            hypothesis = str(doc.get("hypothesis", "")).strip()
+            label = doc.get("label")
+
+            if not premise or not hypothesis or label not in {0, 1, 2}:
+                log.debug(
+                    "Skipping doc due to missing/invalid fields",
+                    extra={"doc": doc},
+                )
+                return None
+
+            labels = {0: "True", 1: "False", 2: "Neither"}
+            correct = labels[label]
+            incorrect = labels[(label + 1) % 3]
+
+            formatted_question = f"{premise}\nQuestion: {hypothesis}. True, False, or Neither?"
+
+            metadata = {
+                "label": "cb",
+            }
+
+            return self._build_pair(
+                question=formatted_question,
+                correct=correct,
+                incorrect=incorrect,
+                metadata=metadata,
+            )
+
+        except Exception as exc:
+            log.error("Error extracting pair from doc", exc_info=exc, extra={"doc": doc})
+            return None
+
+    @staticmethod
+    def _build_pair(
+        question: str,
+        correct: str,
+        incorrect: str,
+        metadata: dict[str, Any] | None = None,
+    ) -> ContrastivePair:
+        positive_response = PositiveResponse(model_response=correct)
+        negative_response = NegativeResponse(model_response=incorrect)
+        return ContrastivePair(prompt=question, positive_response=positive_response, negative_response=negative_response, label=(metadata or {}).get("label"))
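
A minimal usage sketch for the extractor above. Both the no-argument constructor and the lm-eval task loading via get_task_dict are assumptions; the diff does not show how wisent's own CLI wires this up:

from lm_eval.tasks import get_task_dict  # assumes lm-eval is installed

from wisent.core.contrastive_pairs.lm_eval_pairs.lm_task_extractors.cb import CBExtractor

task = get_task_dict(["cb"])["cb"]  # illustrative task loading
pairs = CBExtractor().extract_contrastive_pairs(task, limit=5)
for pair in pairs:
    # each pair carries the formatted prompt plus one correct and one incorrect answer
    print(pair.prompt, "->", pair.positive_response.model_response)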
wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py
@@ -0,0 +1,118 @@
+from __future__ import annotations
+
+from typing import Any, TYPE_CHECKING
+
+from wisent.core.contrastive_pairs.core.pair import ContrastivePair
+from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
+from wisent.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
+from wisent.core.cli_logger import setup_logger, bind
+
+if TYPE_CHECKING:
+    from lm_eval.api.task import ConfigurableTask
+
+
+__all__ = ["COPAExtractor"]
+_LOG = setup_logger(__name__)
+
+
+class COPAExtractor(LMEvalBenchmarkExtractor):
+    """Extractor for the COPA benchmark."""
+
+    def extract_contrastive_pairs(
+        self,
+        lm_eval_task_data: ConfigurableTask,
+        limit: int | None = None,
+    ) -> list[ContrastivePair]:
+        """
+        Build contrastive pairs from COPA docs.
+
+        COPA schema:
+            - premise: str
+            - choice1, choice2: str
+            - question: str
+            - label: 0 or 1
+
+        Args:
+            lm_eval_task_data: lm-eval task instance for COPA.
+            limit: Optional maximum number of pairs to produce.
+
+        Returns:
+            A list of ContrastivePair objects.
+        """
+        log = bind(_LOG, task=getattr(lm_eval_task_data, "NAME", "unknown"))
+
+        max_items = self._normalize_limit(limit)
+        docs = self.load_docs(lm_eval_task_data, max_items)
+
+        pairs: list[ContrastivePair] = []
+
+        log.info("Extracting contrastive pairs", extra={"doc_count": len(docs)})
+
+        for doc in docs:
+            pair = self._extract_pair_from_doc(doc)
+            if pair is not None:
+                pairs.append(pair)
+            if max_items is not None and len(pairs) >= max_items:
+                break
+
+        if not pairs:
+            task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
+            log.warning("No valid COPA pairs extracted", extra={"task": task_name})
+
+        return pairs
+
+    def _extract_pair_from_doc(self, doc: dict[str, Any]) -> ContrastivePair | None:
+        """
+        Convert a single COPA doc into a ContrastivePair, if possible.
+        Returns None when required fields are missing or malformed.
+        """
+        log = bind(_LOG, doc_id=doc.get("id", "unknown"))
+
+        try:
+            premise = str(doc.get("premise", "")).strip()
+            choice1 = str(doc.get("choice1", "")).strip()
+            choice2 = str(doc.get("choice2", "")).strip()
+            question = str(doc.get("question", "")).strip()
+            label = doc.get("label")
+
+
+            if not premise or not choice1 or not choice2 or not question or label not in {0, 1}:
+                log.debug(
+                    "Skipping doc due to missing/invalid fields",
+                    extra={"doc": doc},
+                )
+                return None
+
+            fills = {"cause": "because", "effect": "therefore"}
+
+            question = f"{premise.rstrip('.')} {fills[question]}"
+            formatted_question = f"{question}\nA. {choice1}\nB. {choice2}"
+
+            correct = choice1 if label == 0 else choice2
+            incorrect = choice2 if label == 0 else choice1
+
+            metadata = {
+                "label": "copa",
+            }
+
+            return self._build_pair(
+                question=formatted_question,
+                correct=correct,
+                incorrect=incorrect,
+                metadata=metadata,
+            )
+
+        except Exception as exc:
+            log.error("Error extracting pair from doc", exc_info=exc, extra={"doc": doc})
+            return None
+
+    @staticmethod
+    def _build_pair(
+        question: str,
+        correct: str,
+        incorrect: str,
+        metadata: dict[str, Any] | None = None,
+    ) -> ContrastivePair:
+        positive_response = PositiveResponse(model_response=correct)
+        negative_response = NegativeResponse(model_response=incorrect)
+        return ContrastivePair(prompt=question, positive_response=positive_response, negative_response=negative_response, label=(metadata or {}).get("label"))
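
For reference, how the cause/effect connective mapping above plays out on a sample doc (field values invented; assumes COPAExtractor constructs without arguments):

from wisent.core.contrastive_pairs.lm_eval_pairs.lm_task_extractors.copa import COPAExtractor

doc = {
    "premise": "The man broke his toe.",
    "choice1": "He got a hole in his sock.",
    "choice2": "He dropped a hammer on his foot.",
    "question": "cause",  # mapped to "because"; "effect" would map to "therefore"
    "label": 1,           # choice2 is the correct alternative
}
pair = COPAExtractor()._extract_pair_from_doc(doc)
# pair.prompt:
#   The man broke his toe because
#   A. He got a hole in his sock.
#   B. He dropped a hammer on his foot.
# pair.positive_response.model_response == "He dropped a hammer on his foot."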
wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py
@@ -0,0 +1,146 @@
+from __future__ import annotations
+
+import random
+from typing import Any, TYPE_CHECKING
+
+from wisent.core.contrastive_pairs.core.pair import ContrastivePair
+from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
+from wisent.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
+from wisent.core.cli_logger import setup_logger, bind
+
+if TYPE_CHECKING:
+    from lm_eval.api.task import ConfigurableTask
+
+
+__all__ = ["CoQAExtractor"]
+_LOG = setup_logger(__name__)
+
+
+class CoQAExtractor(LMEvalBenchmarkExtractor):
+    """Extractor for the CoQA benchmark."""
+
+    def extract_contrastive_pairs(
+        self,
+        lm_eval_task_data: ConfigurableTask,
+        limit: int | None = None,
+    ) -> list[ContrastivePair]:
+        """
+        Build contrastive pairs from CoQA docs.
+
+        CoQA schema:
+            - story: str
+            - questions: dict with an "input_text" list of str
+            - answers: dict with an "input_text" list of str
+
+        Args:
+            lm_eval_task_data: lm-eval task instance for CoQA.
+            limit: Optional maximum number of pairs to produce.
+
+        Returns:
+            A list of ContrastivePair objects.
+        """
+        log = bind(_LOG, task=getattr(lm_eval_task_data, "NAME", "unknown"))
+
+        max_items = self._normalize_limit(limit)
+        docs = self.load_docs(lm_eval_task_data, max_items)
+
+        pairs: list[ContrastivePair] = []
+
+        log.info("Extracting contrastive pairs", extra={"doc_count": len(docs)})
+
+        for doc in docs:
+            pair = self._extract_pair_from_doc(doc)
+            if pair is not None:
+                pairs.append(pair)
+            if max_items is not None and len(pairs) >= max_items:
+                break
+
+        if not pairs:
+            task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
+            log.warning("No valid CoQA pairs extracted", extra={"task": task_name})
+
+        return pairs
+
+    def _extract_pair_from_doc(self, doc: dict[str, Any]) -> ContrastivePair | None:
+        """
+        Convert a single CoQA doc into a ContrastivePair, if possible.
+        Returns None when required fields are missing or malformed.
+        """
+        log = bind(_LOG, doc_id=doc.get("id", "unknown"))
+
+        try:
+            story = str(doc.get("story", ""))
+            questions = doc.get("questions", {})
+            answers = doc.get("answers", {})
+
+            if not story or not questions or not answers:
+                log.debug(
+                    "Skipping doc due to missing/invalid fields",
+                    extra={"doc": doc},
+                )
+                return None
+
+            qs = questions["input_text"]
+            asw = answers["input_text"]
+
+            lines = []
+            lines.append(story.strip())
+            lines.append("")
+
+            pairs_count = max(0, min(len(qs) - 1, len(asw)))
+            for q, a in zip(qs[:pairs_count], asw[:pairs_count]):
+                lines.append(f"Q: {q}")
+                lines.append(f"A: {a}")
+
+            if qs:
+                lines.append(f"Q: {qs[-1]}")
+
+            formatted_question = "\n".join(lines)
+
+            correct = asw[-1] if len(asw) == len(qs) else "no"
+            incorrect = None
+            # Generate an incorrect answer from the correct one
+            try:
+                # Try to interpret the answer as a number
+                num = float(correct)
+                # Keep integer answers integral
+                if num.is_integer():
+                    incorrect = str(int(num) + 1)
+                else:
+                    incorrect = str(num + 1)
+            except ValueError:
+                # It's a string: shuffle the letters, appending "k" below if unchanged
+                letters = list(correct)
+                incorrect = correct
+                random.shuffle(letters)
+                incorrect = ''.join(letters)
+                if incorrect == correct:
+                    incorrect += "k"
+
+            formatted_question = f"{formatted_question}\nA:\nA. {incorrect}\nB. {correct}"
+
+            metadata = {
+                "label": "coqa",
+            }
+
+            return self._build_pair(
+                question=formatted_question,
+                correct=correct,
+                incorrect=incorrect,
+                metadata=metadata,
+            )
+
+        except Exception as exc:
+            log.error("Error extracting pair from doc", exc_info=exc, extra={"doc": doc})
+            return None
+
+    @staticmethod
+    def _build_pair(
+        question: str,
+        correct: str,
+        incorrect: str,
+        metadata: dict[str, Any] | None = None,
+    ) -> ContrastivePair:
+        positive_response = PositiveResponse(model_response=correct)
+        negative_response = NegativeResponse(model_response=incorrect)
+        return ContrastivePair(prompt=question, positive_response=positive_response, negative_response=negative_response, label=(metadata or {}).get("label"))
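
The incorrect-answer heuristic above reappears verbatim in the DROP extractor below, so it is worth reading in isolation; a self-contained restatement:

import random

def make_incorrect(correct: str) -> str:
    """Restatement of the heuristic: numeric answers get +1; string answers
    get their characters shuffled, with a "k" appended if the shuffle is a no-op."""
    try:
        num = float(correct)
        return str(int(num) + 1) if num.is_integer() else str(num + 1)
    except ValueError:
        letters = list(correct)
        random.shuffle(letters)
        shuffled = "".join(letters)
        return shuffled if shuffled != correct else shuffled + "k"

assert make_incorrect("4") == "5"
assert make_incorrect("2.5") == "3.5"
assert make_incorrect("no") in {"on", "nok"}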
wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py
@@ -0,0 +1,129 @@
+from __future__ import annotations
+
+import random
+from typing import Any, TYPE_CHECKING
+
+from wisent.core.contrastive_pairs.core.pair import ContrastivePair
+from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
+from wisent.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
+from wisent.core.cli_logger import setup_logger, bind
+
+if TYPE_CHECKING:
+    from lm_eval.api.task import ConfigurableTask
+
+
+__all__ = ["DROPExtractor"]
+_LOG = setup_logger(__name__)
+
+
+class DROPExtractor(LMEvalBenchmarkExtractor):
+    """Extractor for the DROP benchmark."""
+
+    def extract_contrastive_pairs(
+        self,
+        lm_eval_task_data: ConfigurableTask,
+        limit: int | None = None,
+    ) -> list[ContrastivePair]:
+        """
+        Build contrastive pairs from DROP docs.
+
+        DROP schema:
+            - passage: str
+            - question: str
+            - answers: list of lists
+
+        Args:
+            lm_eval_task_data: lm-eval task instance for DROP.
+            limit: Optional maximum number of pairs to produce.
+
+        Returns:
+            A list of ContrastivePair objects.
+        """
+        log = bind(_LOG, task=getattr(lm_eval_task_data, "NAME", "unknown"))
+
+        max_items = self._normalize_limit(limit)
+        docs = self.load_docs(lm_eval_task_data, max_items)
+
+        pairs: list[ContrastivePair] = []
+
+        log.info("Extracting contrastive pairs", extra={"doc_count": len(docs)})
+
+        for doc in docs:
+            pair = self._extract_pair_from_doc(doc)
+            if pair is not None:
+                pairs.append(pair)
+            if max_items is not None and len(pairs) >= max_items:
+                break
+
+        if not pairs:
+            task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
+            log.warning("No valid DROP pairs extracted", extra={"task": task_name})
+
+        return pairs
+
+    def _extract_pair_from_doc(self, doc: dict[str, Any]) -> ContrastivePair | None:
+        """
+        Convert a single DROP doc into a ContrastivePair, if possible.
+        Returns None when required fields are missing or malformed.
+        """
+        log = bind(_LOG, doc_id=doc.get("id", "unknown"))
+
+        try:
+            passage = str(doc.get("passage", "")).strip()
+            question = str(doc.get("question", "")).strip()
+            answers = doc.get("answers")
+            answer = answers[0]
+
+            if not passage or not question or not answer:
+                log.debug(
+                    "Skipping doc due to missing/invalid fields",
+                    extra={"doc": doc},
+                )
+                return None
+
+            correct = answer[0]
+
+            # Generate an incorrect answer from the correct one
+            try:
+                # Try to interpret the answer as a number
+                num = float(correct)
+                # Keep integer answers integral
+                if num.is_integer():
+                    incorrect = str(int(num) + 1)
+                else:
+                    incorrect = str(num + 1)
+            except ValueError:
+                # It's a string: shuffle the letters, appending "k" below if unchanged
+                letters = list(correct)
+                random.shuffle(letters)
+                incorrect = ''.join(letters)
+                if correct == incorrect:
+                    incorrect += "k"
+
+            formatted_question = f"{passage} {question}\nA. {incorrect}\nB. {correct}"
+
+            metadata = {
+                "label": "drop",
+            }
+
+            return self._build_pair(
+                question=formatted_question,
+                correct=correct,
+                incorrect=incorrect,
+                metadata=metadata,
+            )
+
+        except Exception as exc:
+            log.error("Error extracting pair from doc", exc_info=exc, extra={"doc": doc})
+            return None
+
+    @staticmethod
+    def _build_pair(
+        question: str,
+        correct: str,
+        incorrect: str,
+        metadata: dict[str, Any] | None = None,
+    ) -> ContrastivePair:
+        positive_response = PositiveResponse(model_response=correct)
+        negative_response = NegativeResponse(model_response=incorrect)
+        return ContrastivePair(prompt=question, positive_response=positive_response, negative_response=negative_response, label=(metadata or {}).get("label"))
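
A worked example of the answers indexing above (field values invented; assumes DROPExtractor constructs without arguments):

from wisent.core.contrastive_pairs.lm_eval_pairs.lm_task_extractors.drop import DROPExtractor

doc = {
    "passage": "The team scored 21 points in the first half.",
    "question": "How many points did the team score in the first half?",
    "answers": [["21"]],  # list of answer groups; the extractor reads answers[0][0]
}
pair = DROPExtractor()._extract_pair_from_doc(doc)
# correct == "21"; the numeric branch of the shared heuristic yields incorrect == "22"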
wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gsm8k.py
@@ -0,0 +1,119 @@
+from __future__ import annotations
+
+import random
+from typing import Any, TYPE_CHECKING
+
+from wisent.core.contrastive_pairs.core.pair import ContrastivePair
+from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
+from wisent.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
+from wisent.core.cli_logger import setup_logger, bind
+
+if TYPE_CHECKING:
+    from lm_eval.api.task import ConfigurableTask
+
+
+__all__ = ["GSM8KExtractor"]
+_LOG = setup_logger(__name__)
+
+
+class GSM8KExtractor(LMEvalBenchmarkExtractor):
+    """Extractor for the GSM8K benchmark."""
+
+    def extract_contrastive_pairs(
+        self,
+        lm_eval_task_data: ConfigurableTask,
+        limit: int | None = None,
+        preferred_doc: str | None = None,
+    ) -> list[ContrastivePair]:
+        """
+        Build contrastive pairs from GSM8K docs.
+
+        GSM8K schema:
+            - question: str
+            - answer: str
+
+        Args:
+            lm_eval_task_data: lm-eval task instance for GSM8K.
+            limit: Optional maximum number of pairs to produce.
+            preferred_doc: Optional preferred document source ("validation", "test", "training", "fewshot").
+
+        Returns:
+            A list of ContrastivePair objects.
+        """
+        log = bind(_LOG, task=getattr(lm_eval_task_data, "NAME", "unknown"))
+
+        max_items = self._normalize_limit(limit)
+        docs = self.load_docs(lm_eval_task_data, max_items, preferred_doc=preferred_doc)
+
+        pairs: list[ContrastivePair] = []
+
+        log.info("Extracting contrastive pairs", extra={"doc_count": len(docs)})
+
+        for doc in docs:
+            pair = self._extract_pair_from_doc(doc)
+            if pair is not None:
+                pairs.append(pair)
+            if max_items is not None and len(pairs) >= max_items:
+                break
+
+        if not pairs:
+            task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
+            log.warning("No valid GSM8K pairs extracted", extra={"task": task_name})
+
+        return pairs
+
+    def _extract_pair_from_doc(self, doc: dict[str, Any]) -> ContrastivePair | None:
+        """
+        Convert a single GSM8K doc into a ContrastivePair, if possible.
+        Returns None when required fields are missing or malformed.
+        """
+        log = bind(_LOG, doc_id=doc.get("id", "unknown"))
+
+        try:
+            question = str(doc.get("question", "")).strip()
+            answer = str(doc.get("answer", "")).strip()
+
+            if not question or not answer:
+                log.debug(
+                    "Skipping doc due to missing/invalid fields",
+                    extra={"doc": doc},
+                )
+                return None
+
+            numerical_answer = answer
+            if "####" not in answer:
+                log.debug("Skipping doc due to missing numerical answer", extra={"doc": doc})
+                return None
+            numerical_answer = answer.split("####")[-1].strip()
+
+            correct = numerical_answer
+            incorrect_val = float(numerical_answer.replace(',', '')) + 1
+            incorrect = str(int(incorrect_val)) if incorrect_val == int(incorrect_val) else str(incorrect_val)
+
+            formatted_question = f"Question: {question}"
+
+            metadata = {
+                "label": "gsm8k",
+            }
+
+            return self._build_pair(
+                question=formatted_question,
+                correct=correct,
+                incorrect=incorrect,
+                metadata=metadata,
+            )
+
+        except Exception as exc:
+            log.error("Error extracting pair from doc", exc_info=exc, extra={"doc": doc})
+            return None
+
+    @staticmethod
+    def _build_pair(
+        question: str,
+        correct: str,
+        incorrect: str,
+        metadata: dict[str, Any] | None = None,
+    ) -> ContrastivePair:
+        positive_response = PositiveResponse(model_response=correct)
+        negative_response = NegativeResponse(model_response=incorrect)
+        return ContrastivePair(prompt=question, positive_response=positive_response, negative_response=negative_response, label=(metadata or {}).get("label"))
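
The GSM8K extractor relies on the dataset's "####" convention to mark the final numeric answer; a self-contained trace of that parsing (the answer text is invented):

answer = "Natalia sold 48 / 2 = 24 clips in May, so 48 + 24 = 72 in total.\n#### 72"

numerical_answer = answer.split("####")[-1].strip()           # "72"
incorrect_val = float(numerical_answer.replace(",", "")) + 1  # 73.0
incorrect = str(int(incorrect_val)) if incorrect_val == int(incorrect_val) else str(incorrect_val)

assert (numerical_answer, incorrect) == ("72", "73")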