wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of wisent might be problematic. Click here for more details.
- wisent/__init__.py +1 -8
- wisent/benchmarks/__init__.py +0 -0
- wisent/benchmarks/coding/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
- wisent/benchmarks/coding/metrics/evaluator.py +275 -0
- wisent/benchmarks/coding/metrics/passk.py +66 -0
- wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
- wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
- wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
- wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
- wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
- wisent/benchmarks/coding/providers/__init__.py +18 -0
- wisent/benchmarks/coding/providers/core/__init__.py +0 -0
- wisent/benchmarks/coding/providers/core/atoms.py +31 -0
- wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
- wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
- wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
- wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
- wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
- wisent/classifiers/__init__.py +0 -0
- wisent/classifiers/core/__init__.py +0 -0
- wisent/classifiers/core/atoms.py +747 -0
- wisent/classifiers/models/__init__.py +0 -0
- wisent/classifiers/models/logistic.py +29 -0
- wisent/classifiers/models/mlp.py +47 -0
- wisent/cli/__init__.py +0 -0
- wisent/cli/classifiers/__init__.py +0 -0
- wisent/cli/classifiers/classifier_rotator.py +137 -0
- wisent/cli/cli_logger.py +142 -0
- wisent/cli/data_loaders/__init__.py +0 -0
- wisent/cli/data_loaders/data_loader_rotator.py +96 -0
- wisent/cli/evaluators/__init__.py +0 -0
- wisent/cli/evaluators/evaluator_rotator.py +148 -0
- wisent/cli/steering_methods/__init__.py +0 -0
- wisent/cli/steering_methods/steering_rotator.py +110 -0
- wisent/cli/wisent_cli/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
- wisent/cli/wisent_cli/commands/listing.py +154 -0
- wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
- wisent/cli/wisent_cli/main.py +93 -0
- wisent/cli/wisent_cli/shell.py +80 -0
- wisent/cli/wisent_cli/ui.py +69 -0
- wisent/cli/wisent_cli/util/__init__.py +0 -0
- wisent/cli/wisent_cli/util/aggregations.py +43 -0
- wisent/cli/wisent_cli/util/parsing.py +126 -0
- wisent/cli/wisent_cli/version.py +4 -0
- wisent/core/__init__.py +27 -0
- wisent/core/activations/__init__.py +0 -0
- wisent/core/activations/activations_collector.py +338 -0
- wisent/core/activations/core/__init__.py +0 -0
- wisent/core/activations/core/atoms.py +216 -0
- wisent/core/agent/__init__.py +18 -0
- wisent/core/agent/budget.py +638 -0
- wisent/core/agent/device_benchmarks.py +685 -0
- wisent/core/agent/diagnose/__init__.py +55 -0
- wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
- wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
- wisent/core/agent/diagnose/create_classifier.py +1154 -0
- wisent/core/agent/diagnose/response_diagnostics.py +268 -0
- wisent/core/agent/diagnose/select_classifiers.py +506 -0
- wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
- wisent/core/agent/diagnose/tasks/__init__.py +33 -0
- wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
- wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
- wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
- wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
- wisent/core/agent/diagnose.py +242 -0
- wisent/core/agent/steer.py +212 -0
- wisent/core/agent/timeout.py +134 -0
- wisent/core/autonomous_agent.py +1234 -0
- wisent/core/bigcode_integration.py +583 -0
- wisent/core/contrastive_pairs/__init__.py +15 -0
- wisent/core/contrastive_pairs/core/__init__.py +0 -0
- wisent/core/contrastive_pairs/core/atoms.py +45 -0
- wisent/core/contrastive_pairs/core/buliders.py +59 -0
- wisent/core/contrastive_pairs/core/pair.py +178 -0
- wisent/core/contrastive_pairs/core/response.py +152 -0
- wisent/core/contrastive_pairs/core/serialization.py +300 -0
- wisent/core/contrastive_pairs/core/set.py +133 -0
- wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
- wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
- wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
- wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
- wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
- wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
- wisent/core/data_loaders/__init__.py +0 -0
- wisent/core/data_loaders/core/__init__.py +0 -0
- wisent/core/data_loaders/core/atoms.py +98 -0
- wisent/core/data_loaders/loaders/__init__.py +0 -0
- wisent/core/data_loaders/loaders/custom.py +120 -0
- wisent/core/data_loaders/loaders/lm_loader.py +218 -0
- wisent/core/detection_handling.py +257 -0
- wisent/core/download_full_benchmarks.py +1386 -0
- wisent/core/evaluators/__init__.py +0 -0
- wisent/core/evaluators/oracles/__init__.py +0 -0
- wisent/core/evaluators/oracles/interactive.py +73 -0
- wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
- wisent/core/evaluators/oracles/user_specified.py +67 -0
- wisent/core/hyperparameter_optimizer.py +429 -0
- wisent/core/lm_eval_harness_ground_truth.py +1396 -0
- wisent/core/log_likelihoods_evaluator.py +321 -0
- wisent/core/managed_cached_benchmarks.py +595 -0
- wisent/core/mixed_benchmark_sampler.py +364 -0
- wisent/core/model_config_manager.py +330 -0
- wisent/core/model_persistence.py +317 -0
- wisent/core/models/__init__.py +0 -0
- wisent/core/models/core/__init__.py +0 -0
- wisent/core/models/core/atoms.py +460 -0
- wisent/core/models/wisent_model.py +727 -0
- wisent/core/multi_steering.py +316 -0
- wisent/core/optuna/__init__.py +57 -0
- wisent/core/optuna/classifier/__init__.py +25 -0
- wisent/core/optuna/classifier/activation_generator.py +349 -0
- wisent/core/optuna/classifier/classifier_cache.py +509 -0
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
- wisent/core/optuna/steering/__init__.py +0 -0
- wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
- wisent/core/optuna/steering/data_utils.py +342 -0
- wisent/core/optuna/steering/metrics.py +474 -0
- wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
- wisent/core/optuna/steering/steering_optimization.py +1111 -0
- wisent/core/parser.py +1668 -0
- wisent/core/prompts/__init__.py +0 -0
- wisent/core/prompts/core/__init__.py +0 -0
- wisent/core/prompts/core/atom.py +57 -0
- wisent/core/prompts/core/prompt_formater.py +157 -0
- wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
- wisent/core/representation.py +5 -0
- wisent/core/sample_size_optimizer.py +648 -0
- wisent/core/sample_size_optimizer_v2.py +355 -0
- wisent/core/save_results.py +277 -0
- wisent/core/steering.py +652 -0
- wisent/core/steering_method.py +26 -0
- wisent/core/steering_methods/__init__.py +0 -0
- wisent/core/steering_methods/core/__init__.py +0 -0
- wisent/core/steering_methods/core/atoms.py +153 -0
- wisent/core/steering_methods/methods/__init__.py +0 -0
- wisent/core/steering_methods/methods/caa.py +44 -0
- wisent/core/steering_optimizer.py +1297 -0
- wisent/core/task_interface.py +132 -0
- wisent/core/task_selector.py +189 -0
- wisent/core/tasks/__init__.py +175 -0
- wisent/core/tasks/aime_task.py +141 -0
- wisent/core/tasks/file_task.py +211 -0
- wisent/core/tasks/hle_task.py +180 -0
- wisent/core/tasks/hmmt_task.py +119 -0
- wisent/core/tasks/livecodebench_task.py +201 -0
- wisent/core/tasks/livemathbench_task.py +158 -0
- wisent/core/tasks/lm_eval_task.py +455 -0
- wisent/core/tasks/math500_task.py +84 -0
- wisent/core/tasks/polymath_task.py +146 -0
- wisent/core/tasks/supergpqa_task.py +220 -0
- wisent/core/time_estimator.py +149 -0
- wisent/core/timing_calibration.py +174 -0
- wisent/core/tracking/__init__.py +54 -0
- wisent/core/tracking/latency.py +618 -0
- wisent/core/tracking/memory.py +359 -0
- wisent/core/trainers/__init__.py +0 -0
- wisent/core/trainers/core/__init__.py +11 -0
- wisent/core/trainers/core/atoms.py +45 -0
- wisent/core/trainers/steering_trainer.py +271 -0
- wisent/core/user_model_config.py +158 -0
- wisent/opti/__init__.py +0 -0
- wisent/opti/core/__init__.py +0 -0
- wisent/opti/core/atoms.py +175 -0
- wisent/opti/methods/__init__.py +0 -0
- wisent/opti/methods/opti_classificator.py +172 -0
- wisent/opti/methods/opti_steering.py +138 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/core/atoms.py +58 -0
- wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
- wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
- wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
- wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/atoms.py +25 -0
- wisent/synthetic/db_instructions/mini_dp.py +37 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/core/atoms.py +73 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/core.py +68 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
- wisent/synthetic/generators/pairs_generator.py +179 -0
- wisent-0.5.1.dist-info/METADATA +67 -0
- wisent-0.5.1.dist-info/RECORD +218 -0
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
- wisent/activations/__init__.py +0 -9
- wisent/activations/client.py +0 -97
- wisent/activations/extractor.py +0 -251
- wisent/activations/models.py +0 -95
- wisent/client.py +0 -45
- wisent/control_vector/__init__.py +0 -9
- wisent/control_vector/client.py +0 -85
- wisent/control_vector/manager.py +0 -168
- wisent/control_vector/models.py +0 -70
- wisent/inference/__init__.py +0 -9
- wisent/inference/client.py +0 -103
- wisent/inference/inferencer.py +0 -250
- wisent/inference/models.py +0 -66
- wisent/utils/__init__.py +0 -3
- wisent/utils/auth.py +0 -30
- wisent/utils/http.py +0 -228
- wisent/version.py +0 -3
- wisent-0.1.1.dist-info/METADATA +0 -142
- wisent-0.1.1.dist-info/RECORD +0 -23
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,238 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Iterable, Mapping
|
|
4
|
+
from typing import Any, Sequence, TYPE_CHECKING
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from wisent_guard.core.contrastive_pairs.core.pair import ContrastivePair
|
|
10
|
+
from lm_eval.api.task import ConfigurableTask
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Names exported by ``from ... import *``: the error hierarchy and the
# abstract extractor base class defined in this module.
__all__ = ["UnsupportedLMEvalBenchmarkError", "NoLabelledDocsAvailableError",
           "LMEvalBenchmarkExtractor"]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class UnsupportedLMEvalBenchmarkError(Exception):
    """Signal that no compatible extractor exists for a benchmark/task."""


class NoLabelledDocsAvailableError(UnsupportedLMEvalBenchmarkError):
    """
    Signal that an lm-eval task yielded no labeled documents.

    Raised when the task exposes none of its validation/test/training/fewshot
    doc getters with usable content, and its dataset metadata is not enough
    to load a split directly. As a subclass of
    UnsupportedLMEvalBenchmarkError, it is also caught by handlers for that
    more general error.
    """
|
|
32
|
+
|
|
33
|
+
class LMEvalBenchmarkExtractor(ABC):
|
|
34
|
+
"""
|
|
35
|
+
Abstract base class for lm-eval benchmark-specific extractors.
|
|
36
|
+
|
|
37
|
+
Subclasses should implement :meth:'extract_contrastive_pairs' to transform
|
|
38
|
+
task documents into a list of :class:'ContrastivePair' instances.
|
|
39
|
+
|
|
40
|
+
Utility methods are provided to load the most appropriate labeled documents
|
|
41
|
+
from a task, with a clear order of preference and a robust dataset fallback.
|
|
42
|
+
"""
|
|
43
|
+
|
|
44
|
+
@abstractmethod
|
|
45
|
+
def extract_contrastive_pairs(
|
|
46
|
+
self,
|
|
47
|
+
lm_eval_task_data: ConfigurableTask,
|
|
48
|
+
limit: int | None = None,
|
|
49
|
+
) -> list[ContrastivePair]:
|
|
50
|
+
"""
|
|
51
|
+
Extract contrastive pairs from the provided lm-eval task.
|
|
52
|
+
|
|
53
|
+
arguments:
|
|
54
|
+
lm_eval_task_data:
|
|
55
|
+
An lm-eval task instance.
|
|
56
|
+
limit:
|
|
57
|
+
Optional upper bound on the number of pairs to return.
|
|
58
|
+
Values <= 0 are treated as "no limit".
|
|
59
|
+
|
|
60
|
+
returns:
|
|
61
|
+
A list of :class:'ContrastivePair'.
|
|
62
|
+
"""
|
|
63
|
+
raise NotImplementedError
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@classmethod
|
|
67
|
+
def load_docs(
|
|
68
|
+
cls,
|
|
69
|
+
lm_eval_task_data: ConfigurableTask,
|
|
70
|
+
limit: int | None = None,
|
|
71
|
+
) -> list[dict[str, Any]]:
|
|
72
|
+
"""
|
|
73
|
+
Load labeled documents from the most appropriate split with a clear
|
|
74
|
+
preference order:
|
|
75
|
+
|
|
76
|
+
validation → test → train → fewshot
|
|
77
|
+
|
|
78
|
+
If none are available, attempts a dataset fallback using
|
|
79
|
+
'datasets.load_dataset' with the task's declared metadata
|
|
80
|
+
(e.g., 'dataset_path'/'dataset_name', 'dataset_config_name',
|
|
81
|
+
and 'fewshot_split').
|
|
82
|
+
|
|
83
|
+
arguments:
|
|
84
|
+
lm_eval_task_data:
|
|
85
|
+
Task object from lm-eval.
|
|
86
|
+
limit:
|
|
87
|
+
Optional maximum number of documents to return.
|
|
88
|
+
Values <= 0 are treated as "no limit".
|
|
89
|
+
|
|
90
|
+
returns:
|
|
91
|
+
A list of document dictionaries.
|
|
92
|
+
|
|
93
|
+
raises:
|
|
94
|
+
NoLabelledDocsAvailableError:
|
|
95
|
+
If no labeled documents are available.
|
|
96
|
+
RuntimeError:
|
|
97
|
+
If a dataset fallback is attempted and fails to load.
|
|
98
|
+
"""
|
|
99
|
+
max_items = cls._normalize_limit(limit)
|
|
100
|
+
|
|
101
|
+
preferred_sources: Sequence[tuple[str, str]] = (
|
|
102
|
+
("has_validation_docs", "validation_docs"),
|
|
103
|
+
("has_test_docs", "test_docs"),
|
|
104
|
+
("has_training_docs", "training_docs"),
|
|
105
|
+
("has_fewshot_docs", "fewshot_docs"),
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
for has_method, docs_method in preferred_sources:
|
|
109
|
+
if cls._has_true(lm_eval_task_data, has_method) and cls._has_callable(
|
|
110
|
+
lm_eval_task_data, docs_method
|
|
111
|
+
):
|
|
112
|
+
docs_iter = getattr(lm_eval_task_data, docs_method)()
|
|
113
|
+
docs_list = cls._coerce_docs_to_dicts(docs_iter, max_items)
|
|
114
|
+
if docs_list:
|
|
115
|
+
return docs_list
|
|
116
|
+
|
|
117
|
+
# Fallback to dataset split (common for tasks relying on fewshot_split).
|
|
118
|
+
docs_list = cls._fallback_load_from_dataset(lm_eval_task_data, max_items)
|
|
119
|
+
if docs_list:
|
|
120
|
+
return docs_list
|
|
121
|
+
|
|
122
|
+
task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
|
|
123
|
+
raise NoLabelledDocsAvailableError(
|
|
124
|
+
f"No labeled documents are available for task '{task_name}'. "
|
|
125
|
+
"The task does not expose validation/test/train/fewshot docs, "
|
|
126
|
+
"and no usable dataset metadata was found for a fallback load.\n\n"
|
|
127
|
+
"Tip: Ensure your task implements at least one of the doc getters "
|
|
128
|
+
"(validation_docs/test_docs/training_docs/fewshot_docs), or that it "
|
|
129
|
+
"declares dataset metadata (dataset_path or dataset_name, "
|
|
130
|
+
"dataset_config_name, and fewshot_split) so a split can be loaded."
|
|
131
|
+
)
|
|
132
|
+
|
|
133
|
+
@staticmethod
|
|
134
|
+
def _normalize_limit(limit: int | None) -> int | None:
|
|
135
|
+
"""
|
|
136
|
+
Normalize limit semantics:
|
|
137
|
+
- None → None (unbounded)
|
|
138
|
+
- <= 0 → None (unbounded)
|
|
139
|
+
- > 0 → limit
|
|
140
|
+
"""
|
|
141
|
+
if limit is None or limit <= 0:
|
|
142
|
+
return None
|
|
143
|
+
return int(limit)
|
|
144
|
+
|
|
145
|
+
@staticmethod
|
|
146
|
+
def _has_callable(obj: Any, name: str) -> bool:
|
|
147
|
+
"""Return True if obj has a callable attribute with the given name."""
|
|
148
|
+
return hasattr(obj, name) and callable(getattr(obj, name))
|
|
149
|
+
|
|
150
|
+
@staticmethod
|
|
151
|
+
def _has_true(obj: Any, name: str) -> bool:
|
|
152
|
+
"""Return True if obj has an attribute that evaluates to True when called or read."""
|
|
153
|
+
attr = getattr(obj, name, None)
|
|
154
|
+
try:
|
|
155
|
+
return bool(attr() if callable(attr) else attr)
|
|
156
|
+
except Exception: # pragma: no cover (defensive)
|
|
157
|
+
return False
|
|
158
|
+
|
|
159
|
+
@classmethod
|
|
160
|
+
def _coerce_docs_to_dicts(
|
|
161
|
+
cls,
|
|
162
|
+
docs_iter: Iterable[Any] | None,
|
|
163
|
+
max_items: int | None,
|
|
164
|
+
) -> list[dict[str, Any]]:
|
|
165
|
+
"""
|
|
166
|
+
Materialize an iterable of docs into a list of dictionaries,
|
|
167
|
+
applying an optional limit.
|
|
168
|
+
"""
|
|
169
|
+
if docs_iter is None:
|
|
170
|
+
return []
|
|
171
|
+
|
|
172
|
+
out: list[dict[str, Any]] = []
|
|
173
|
+
for idx, item in enumerate(docs_iter):
|
|
174
|
+
if max_items is not None and idx >= max_items:
|
|
175
|
+
break
|
|
176
|
+
if isinstance(item, Mapping):
|
|
177
|
+
out.append(dict(item))
|
|
178
|
+
else:
|
|
179
|
+
try:
|
|
180
|
+
out.append(dict(item))
|
|
181
|
+
except Exception as exc:
|
|
182
|
+
raise TypeError(
|
|
183
|
+
"Expected each document to be a mapping-like object that can "
|
|
184
|
+
"be converted to dict. Got type "
|
|
185
|
+
f"{type(item).__name__} with value {item!r}"
|
|
186
|
+
) from exc
|
|
187
|
+
return out
|
|
188
|
+
|
|
189
|
+
@classmethod
|
|
190
|
+
def _fallback_load_from_dataset(
|
|
191
|
+
cls,
|
|
192
|
+
lm_eval_task_data: ConfigurableTask,
|
|
193
|
+
max_items: int | None,
|
|
194
|
+
) -> list[dict[str, Any]]:
|
|
195
|
+
"""
|
|
196
|
+
Attempt to load documents via datasets.load_dataset using the task's
|
|
197
|
+
declared metadata. We prefer 'fewshot_split' if present, since this is
|
|
198
|
+
a common pattern for tasks like (M)MMLU.
|
|
199
|
+
|
|
200
|
+
returns:
|
|
201
|
+
A possibly empty list of docs.
|
|
202
|
+
"""
|
|
203
|
+
dataset_name = getattr(lm_eval_task_data, "dataset_path", None) or getattr(
|
|
204
|
+
lm_eval_task_data, "dataset_name", None
|
|
205
|
+
)
|
|
206
|
+
dataset_config = getattr(lm_eval_task_data, "dataset_config_name", None)
|
|
207
|
+
dataset_split = getattr(lm_eval_task_data, "fewshot_split", None)
|
|
208
|
+
|
|
209
|
+
if not dataset_name or not dataset_split:
|
|
210
|
+
return []
|
|
211
|
+
|
|
212
|
+
try:
|
|
213
|
+
from datasets import load_dataset
|
|
214
|
+
except Exception as exc:
|
|
215
|
+
task_name = getattr(
|
|
216
|
+
lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__
|
|
217
|
+
)
|
|
218
|
+
raise RuntimeError(
|
|
219
|
+
f"Task '{task_name}' specifies dataset metadata but "
|
|
220
|
+
"the 'datasets' library is not available. "
|
|
221
|
+
"Install it via 'pip install datasets' to enable fallback loading."
|
|
222
|
+
) from exc
|
|
223
|
+
|
|
224
|
+
try:
|
|
225
|
+
dataset = load_dataset(
|
|
226
|
+
dataset_name,
|
|
227
|
+
dataset_config if dataset_config else None,
|
|
228
|
+
split=dataset_split,
|
|
229
|
+
)
|
|
230
|
+
except Exception as exc:
|
|
231
|
+
task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
|
|
232
|
+
raise RuntimeError(
|
|
233
|
+
f"Failed to load dataset split via fallback for task '{task_name}'. "
|
|
234
|
+
f"Arguments were: name={dataset_name!r}, config={dataset_config!r}, "
|
|
235
|
+
f"split={dataset_split!r}. Underlying error: {exc}"
|
|
236
|
+
) from exc
|
|
237
|
+
|
|
238
|
+
return cls._coerce_docs_to_dicts(dataset, max_items)
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
__all__ = [
    "EXTRACTORS",
]

# Dotted prefix of the sibling ``lm_task_extractors`` subpackage, derived from
# this module's own import path (``<pkg>.lm_extractor_manifest`` →
# ``<pkg>.lm_task_extractors.``). A hard-coded "wisent_guard." prefix does not
# match the ``wisent`` package that ships these modules. Deriving the prefix
# keeps the references valid whichever top-level package name is installed.
base_import: str = f"{__name__.rpartition('.')[0]}.lm_task_extractors."

EXTRACTORS: dict[str, str] = {
    # key → "module_path:ClassName" (supports dotted attr path after ':')
    "winogrande": f"{base_import}winogrande:WinograndeExtractor",
}
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Type, Union
|
|
4
|
+
import importlib
|
|
5
|
+
import logging
|
|
6
|
+
|
|
7
|
+
from wisent_guard.core.contrastive_pairs.lm_eval_pairs.atoms import (
|
|
8
|
+
LMEvalBenchmarkExtractor,
|
|
9
|
+
UnsupportedLMEvalBenchmarkError,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
from wisent_guard.core.contrastive_pairs.lm_eval_pairs.lm_extractor_manifest import EXTRACTORS as _MANIFEST
|
|
13
|
+
|
|
14
|
+
__all__ = ["register_extractor", "get_extractor"]

# Module-level logger (currently not referenced elsewhere in this module).
LOG = logging.getLogger(__name__)

# Lower-cased task name → extractor reference. Seeded with a shallow copy of the
# static manifest, so that register_extractor() never mutates the manifest dict.
# Values are "module_path:ClassName[.Inner]" strings (resolved lazily by
# _instantiate) or LMEvalBenchmarkExtractor subclasses.
_REGISTRY: dict[str, Union[str, Type[LMEvalBenchmarkExtractor]]] = {**_MANIFEST}
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def register_extractor(name: str, ref: Union[str, Type[LMEvalBenchmarkExtractor]]) -> None:
    """
    Register a new extractor by name.

    arguments:
        name:
            Name/key for the extractor (case-insensitive; surrounding
            whitespace is ignored).
        ref:
            Either a string "module_path:ClassName[.Inner]" or a subclass of
            LMEvalBenchmarkExtractor. String refs are stored as-is and only
            imported when the extractor is requested.

    raises:
        ValueError:
            If the name is empty or the string ref is malformed (no ':' or an
            empty module/class part).
        TypeError:
            If the ref is not a class, or does not subclass LMEvalBenchmarkExtractor.

    example:
        >>> from wisent_guard.core.contrastive_pairs.lm_eval_pairs.lm_extractor_registry import register_extractor
        >>> from wisent_guard.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
        >>> class MyExtractor(LMEvalBenchmarkExtractor): ...
        >>> register_extractor("mytask", MyExtractor)
        >>> register_extractor("mytask2", "my_module:MyExtractor")
    """
    key = (name or "").strip().lower()
    if not key:
        raise ValueError("Extractor name/key must be a non-empty string.")

    if isinstance(ref, str):
        module_path, sep, attr_path = ref.partition(":")
        # Reject refs such as "mod:" or ":Cls" here. Otherwise they would only
        # fail later, at lookup time, with a less helpful import error.
        if not sep or not module_path.strip() or not attr_path.strip():
            raise ValueError("String ref must be 'module_path:ClassName[.Inner]'.")
        _REGISTRY[key] = ref
        return

    # issubclass() raises its own opaque TypeError for non-class arguments
    # (e.g. an extractor *instance*), so check for a class first.
    if not isinstance(ref, type) or not issubclass(ref, LMEvalBenchmarkExtractor):
        raise TypeError(f"{getattr(ref, '__name__', ref)!r} must subclass LMEvalBenchmarkExtractor")

    _REGISTRY[key] = ref
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def get_extractor(task_name: str) -> LMEvalBenchmarkExtractor:
    """
    Look up the extractor registered for ``task_name`` and return a fresh instance.

    arguments:
        task_name:
            Name of the lm-eval benchmark/task (e.g., "winogrande").
            Matching is case-insensitive and exact, after surrounding whitespace
            is stripped.

    returns:
        A newly constructed instance of the matching LMEvalBenchmarkExtractor subclass.

    raises:
        UnsupportedLMEvalBenchmarkError:
            If the name is empty or nothing is registered under it. The message
            lists the known keys.
        ImportError:
            If a string reference cannot be imported/resolved.
        TypeError:
            If the resolved class does not subclass LMEvalBenchmarkExtractor.
    """
    normalized = (task_name or "").strip().lower()
    if not normalized:
        raise UnsupportedLMEvalBenchmarkError("Empty task name is not supported.")

    ref = _REGISTRY.get(normalized)
    if not ref:
        known = ", ".join(sorted(_REGISTRY)) or "(none)"
        raise UnsupportedLMEvalBenchmarkError(
            f"No extractor registered for task '{task_name}'. "
            f"Known: {known}"
        )

    return _instantiate(ref)
|
|
95
|
+
|
|
96
|
+
def _instantiate(ref: Union[str, Type[LMEvalBenchmarkExtractor]]) -> LMEvalBenchmarkExtractor:
    """
    Instantiate an extractor from a string reference or class.

    arguments:
        ref:
            Either a string "module_path:ClassName[.Inner]" or a subclass of
            LMEvalBenchmarkExtractor. The part after ':' may be a dotted path
            to a nested attribute.

    returns:
        An instance of the corresponding LMEvalBenchmarkExtractor subclass,
        constructed with no arguments.

    raises:
        ValueError:
            If a string reference contains no ':' separator.
        ImportError:
            If the extractor class cannot be imported/resolved.
        TypeError:
            If the resolved class does not subclass LMEvalBenchmarkExtractor.
    """
    if not isinstance(ref, str):
        # Class references are subclass-checked when added via register_extractor().
        return ref()

    module_path, sep, attr_path = ref.partition(":")
    if not sep:
        # Manifest entries never pass through register_extractor(), so a
        # malformed string can arrive here. Without this check the caller sees
        # an opaque tuple-unpacking error; ValueError is kept as the type.
        raise ValueError(
            f"Malformed extractor reference {ref!r}; expected 'module_path:ClassName[.Inner]'."
        )

    try:
        mod = importlib.import_module(module_path)
    except Exception as exc:
        raise ImportError(f"Cannot import module '{module_path}' for extractor '{ref}'.") from exc

    # Walk the dotted attribute path (supports nested classes).
    obj: object = mod
    for part in attr_path.split("."):
        try:
            obj = getattr(obj, part)
        except AttributeError as exc:
            raise ImportError(f"Extractor class '{attr_path}' not found in '{module_path}'.") from exc

    if not isinstance(obj, type) or not issubclass(obj, LMEvalBenchmarkExtractor):
        raise TypeError(f"Resolved object '{obj}' is not a LMEvalBenchmarkExtractor subclass.")
    return obj()
|
|
File without changes
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
from wisent_guard.core.contrastive_pairs.core.pair import ContrastivePair
|
|
6
|
+
from wisent_guard.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
|
|
7
|
+
from wisent_guard.core.contrastive_pairs.lm_eval_pairs.atoms import LMEvalBenchmarkExtractor
|
|
8
|
+
from wisent_guard.cli.cli_logger import setup_logger, bind
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from lm_eval.api.task import ConfigurableTask
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
__all__ = ["WinograndeExtractor"]
_LOG = setup_logger(__name__)


class WinograndeExtractor(LMEvalBenchmarkExtractor):
    """Extractor for the Winogrande benchmark."""

    def extract_contrastive_pairs(
        self,
        lm_eval_task_data: ConfigurableTask,
        limit: int | None = None,
    ) -> list[ContrastivePair]:
        """
        Build contrastive pairs from Winogrande docs.

        Winogrande schema:
            - sentence: str (contains a blank)
            - option1, option2: str
            - answer: "1" or "2" (sometimes int-like)

        arguments:
            lm_eval_task_data:
                lm-eval task instance for Winogrande.
            limit:
                Optional maximum number of pairs to produce.
                Values <= 0 are treated as "no limit".

        returns:
            A list of ContrastivePair objects. Malformed docs are skipped, so
            the list may be shorter than ``limit``.

        raises:
            Propagates errors from load_docs (e.g. NoLabelledDocsAvailableError
            when the task yields no labeled documents).
        """
        log = bind(_LOG, task=getattr(lm_eval_task_data, "NAME", "unknown"))

        max_items = self._normalize_limit(limit)
        docs = self.load_docs(lm_eval_task_data, max_items)

        pairs: list[ContrastivePair] = []

        log.info("Extracting contrastive pairs", extra={"doc_count": len(docs)})

        for doc in docs:
            pair = self._extract_pair_from_doc(doc)
            if pair is None:
                continue
            pairs.append(pair)
            if max_items is not None and len(pairs) >= max_items:
                break

        if not pairs:
            task_name = getattr(lm_eval_task_data, "NAME", type(lm_eval_task_data).__name__)
            log.warning("No valid Winogrande pairs extracted", extra={"task": task_name})

        return pairs

    def _extract_pair_from_doc(self, doc: dict[str, Any]) -> ContrastivePair | None:
        """
        Convert a single Winogrande doc into a ContrastivePair, if possible.

        The prompt lists option1 as "A." and option2 as "B.". The option
        selected by ``answer`` becomes the positive response and the other
        option becomes the negative one.

        returns:
            The pair, or None when required fields are missing or malformed
            (any unexpected error is logged and also yields None).
        """
        log = bind(_LOG, doc_id=doc.get("id", "unknown"))

        try:
            sentence = str(doc.get("sentence", "")).strip()
            option1 = str(doc.get("option1", "")).strip()
            option2 = str(doc.get("option2", "")).strip()

            raw_answer = doc.get("answer", "")
            # str() normalizes int answers (1/2) to the "1"/"2" form checked below.
            answer = str(raw_answer).strip()

            if not sentence or not option1 or not option2 or answer not in {"1", "2"}:
                log.debug(
                    "Skipping doc due to missing/invalid fields",
                    extra={"doc": doc},
                )
                return None

            question = f"Complete the sentence: {sentence}"
            formatted_question = f"{question}\nA. {option1}\nB. {option2}"

            correct = option1 if answer == "1" else option2
            incorrect = option2 if answer == "1" else option1

            metadata = {
                "label": "winogrande",
            }

            return self._build_pair(
                question=formatted_question,
                correct=correct,
                incorrect=incorrect,
                metadata=metadata,
            )

        except Exception as exc:
            log.error("Error extracting pair from doc", exc_info=exc, extra={"doc": doc})
            return None

    @staticmethod
    def _build_pair(
        question: str,
        correct: str,
        incorrect: str,
        metadata: dict[str, Any] | None = None,
    ) -> ContrastivePair:
        """
        Assemble a ContrastivePair from a prompt and its correct/incorrect completions.

        arguments:
            question:
                Prompt text.
            correct:
                Completion wrapped as the PositiveResponse.
            incorrect:
                Completion wrapped as the NegativeResponse.
            metadata:
                Optional dict; its "label" entry (if any) becomes the pair label.
                When omitted, the label is None.
        """
        # metadata defaults to None; dereferencing it unconditionally raised
        # AttributeError for callers that relied on the default.
        label = (metadata or {}).get("label")
        positive_response = PositiveResponse(model_response=correct)
        negative_response = NegativeResponse(model_response=incorrect)
        return ContrastivePair(
            prompt=question,
            positive_response=positive_response,
            negative_response=negative_response,
            label=label,
        )
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
from wisent_guard.core.contrastive_pairs.lm_eval_pairs.lm_extractor_registry import get_extractor
|
|
6
|
+
from wisent_guard.cli.cli_logger import setup_logger, bind
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from lm_eval.api.task import ConfigurableTask
|
|
10
|
+
from wisent_guard.core.contrastive_pairs.core.pair import ContrastivePair
|
|
11
|
+
|
|
12
|
+
# Names exported by `from <module> import *`. Every entry must be an attribute
# this module actually defines, otherwise the star-import raises AttributeError.
__all__ = ["lm_build_contrastive_pairs"]
# Module-level logger; per-call context is attached with `bind`.
_LOG = setup_logger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def lm_build_contrastive_pairs(
    task_name: str,
    lm_eval_task: ConfigurableTask,
    limit: int | None = None,
) -> list[ContrastivePair]:
    """
    Build contrastive pairs for an lm-eval task using its registered extractor.

    arguments:
        task_name:
            Name of the lm-eval benchmark/task (e.g., "winogrande"); used to
            look up the extractor via ``get_extractor``.
        lm_eval_task:
            An lm-eval task instance, passed unchanged to the extractor.
        limit:
            Optional upper bound on the number of pairs to return.
            ``None`` and values <= 0 both mean "no limit".

    returns:
        Whatever the extractor's ``extract_contrastive_pairs`` returns,
        expected to be a list of ContrastivePair objects.
    """
    task_log = bind(_LOG, task=task_name or "unknown")
    task_log.info("Building contrastive pairs", extra={"limit": limit})

    # Resolve the extractor first; the name-matching rules live in get_extractor.
    pair_extractor = get_extractor(task_name)
    extractor_cls_name = pair_extractor.__class__.__name__
    task_log.info("Using extractor", extra={"extractor": extractor_cls_name})

    # Collapse every "no limit" spelling (None, 0, negatives) to None.
    if limit is None or limit <= 0:
        max_items = None
    else:
        max_items = int(limit)
    task_log.info("Extracting contrastive pairs", extra={"max_items": max_items})

    # Hand pair construction off to the extractor.
    return pair_extractor.extract_contrastive_pairs(lm_eval_task, limit=max_items)
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
import inspect
|
|
5
|
+
from typing import Any, Dict, Type
|
|
6
|
+
|
|
7
|
+
from typing import TypedDict, Mapping
|
|
8
|
+
from lm_eval.api.task import ConfigurableTask
|
|
9
|
+
from wisent_guard.core.contrastive_pairs.core.set import ContrastivePairSet
|
|
10
|
+
|
|
11
|
+
# LoadDataResult is exported because it is the public return type of
# BaseDataLoader.load, which concrete loaders implement.
__all__ = ["LoadDataResult", "DataLoaderError", "BaseDataLoader"]
|
|
12
|
+
|
|
13
|
+
class LoadDataResult(TypedDict):
    """
    Structured output of ``BaseDataLoader.load``, used for training and evaluation.

    attributes:
        train_qa_pairs:
            The training set of question-answer pairs.
        test_qa_pairs:
            The test set of question-answer pairs.
        task_type:
            The high-level task category (e.g., "classification").
        lm_task_data:
            Task data in the 'lm_eval' repository format, if applicable:
            a single ConfigurableTask, a str-keyed mapping of ConfigurableTask
            objects, or None.

            When training/evaluating steering vectors with 'lm_eval', that
            library is responsible for downloading and preprocessing the data,
            and it provides the evaluation function that compares the steered
            model to the baseline, see: https://github.com/EleutherAI/lm-evaluation-harness.
            For custom data loaders, this is 'None'.
    """
    # Training split of contrastive question-answer pairs.
    train_qa_pairs: ContrastivePairSet
    # Held-out split of contrastive question-answer pairs.
    test_qa_pairs: ContrastivePairSet
    # Free-form task category string.
    task_type: str
    # lm-eval task object(s), or None for custom (non-lm-eval) loaders.
    # NOTE(review): mapping keys are presumably task names — confirm with producers.
    lm_task_data: Mapping[str, ConfigurableTask] | ConfigurableTask | None
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class DataLoaderError(RuntimeError):
    """Signal that a data loader could not finish loading.

    As a RuntimeError subclass it is also caught by RuntimeError handlers.
    BaseDataLoader.get raises it for names that are not registered.
    """

    pass
|
|
41
|
+
|
|
42
|
+
class BaseDataLoader(ABC):
    """
    Abstract base class for data loaders.

    Every concrete (non-abstract) subclass registers itself in a shared
    registry under its ``name`` class attribute when the class is created,
    so defining/importing the subclass is enough to make it retrievable via
    ``BaseDataLoader.get(name)``.
    """

    # Registry key; concrete subclasses must override it. The base value
    # "base" is a placeholder and is rejected at registration time.
    name: str = "base"
    # Human-readable description of the loader.
    description: str = "Abstract data loader"

    # Shared registry (name -> loader class), filled by __init_subclass__.
    # Always accessed via BaseDataLoader so all subclasses share one dict.
    _REGISTRY: Dict[str, Type["BaseDataLoader"]] = {}

    def __init_subclass__(cls, **kwargs):
        """
        Register a newly created concrete subclass in the shared registry.

        Subclasses that still leave abstract methods unimplemented are skipped.

        raises:
            TypeError: if the subclass does not define its own non-empty
                ``name`` (keeping the inherited placeholder "base" counts as
                not defining it).
            ValueError: if another loader is already registered under ``name``.
        """
        super().__init_subclass__(**kwargs)
        if cls is BaseDataLoader:
            return
        if inspect.isabstract(cls):
            return
        loader_name = getattr(cls, "name", None)
        # `name` is always inherited from BaseDataLoader, so a plain truthiness
        # check never caught a missing override; also reject the base
        # placeholder instead of silently registering the subclass as "base".
        if not loader_name or loader_name == BaseDataLoader.name:
            raise TypeError("DataLoader subclasses must define a class attribute `name`.")
        if loader_name in BaseDataLoader._REGISTRY:
            raise ValueError(f"Duplicate data loader name: {loader_name!r}")
        BaseDataLoader._REGISTRY[loader_name] = cls

    def __init__(self, **kwargs: Any) -> None:
        """Keep a copy of the keyword options in ``self.kwargs``."""
        self.kwargs: dict[str, Any] = dict(kwargs)

    @staticmethod
    def _effective_split(split_ratio: float | None) -> float:
        """
        Determine the effective split ratio, defaulting to 0.8 if None.

        arguments:
            split_ratio: Optional float in [0.0, 1.0] or None.

        returns:
            A float in [0.0, 1.0] representing the training split ratio.

        raises:
            ValueError if split_ratio is not in [0.0, 1.0] (NaN included,
            since every comparison with NaN is False).
        """
        if split_ratio is None:
            return 0.8
        if not (0.0 <= split_ratio <= 1.0):
            raise ValueError("split_ratio must be in [0.0, 1.0]")
        return float(split_ratio)

    @abstractmethod
    def load(self, **kwargs: Any) -> LoadDataResult:
        """Return a LoadDataResult (train_qa_pairs, test_qa_pairs, task_type, lm_task_data)."""
        raise NotImplementedError

    @classmethod
    def list_registered(cls) -> dict[str, Type["BaseDataLoader"]]:
        """Return a shallow copy of the registry (mutating it does not affect registration)."""
        return dict(cls._REGISTRY)

    @classmethod
    def get(cls, name: str) -> Type["BaseDataLoader"]:
        """
        Look up a registered loader class by name.

        raises:
            DataLoaderError: if no loader is registered under ``name``.
        """
        try:
            return cls._REGISTRY[name]
        except KeyError as exc:
            raise DataLoaderError(f"Unknown data loader: {name!r}") from exc
|
|
File without changes
|