wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of wisent might be problematic.
- wisent/__init__.py +1 -8
- wisent/benchmarks/__init__.py +0 -0
- wisent/benchmarks/coding/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
- wisent/benchmarks/coding/metrics/evaluator.py +275 -0
- wisent/benchmarks/coding/metrics/passk.py +66 -0
- wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
- wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
- wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
- wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
- wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
- wisent/benchmarks/coding/providers/__init__.py +18 -0
- wisent/benchmarks/coding/providers/core/__init__.py +0 -0
- wisent/benchmarks/coding/providers/core/atoms.py +31 -0
- wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
- wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
- wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
- wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
- wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
- wisent/classifiers/__init__.py +0 -0
- wisent/classifiers/core/__init__.py +0 -0
- wisent/classifiers/core/atoms.py +747 -0
- wisent/classifiers/models/__init__.py +0 -0
- wisent/classifiers/models/logistic.py +29 -0
- wisent/classifiers/models/mlp.py +47 -0
- wisent/cli/__init__.py +0 -0
- wisent/cli/classifiers/__init__.py +0 -0
- wisent/cli/classifiers/classifier_rotator.py +137 -0
- wisent/cli/cli_logger.py +142 -0
- wisent/cli/data_loaders/__init__.py +0 -0
- wisent/cli/data_loaders/data_loader_rotator.py +96 -0
- wisent/cli/evaluators/__init__.py +0 -0
- wisent/cli/evaluators/evaluator_rotator.py +148 -0
- wisent/cli/steering_methods/__init__.py +0 -0
- wisent/cli/steering_methods/steering_rotator.py +110 -0
- wisent/cli/wisent_cli/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
- wisent/cli/wisent_cli/commands/listing.py +154 -0
- wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
- wisent/cli/wisent_cli/main.py +93 -0
- wisent/cli/wisent_cli/shell.py +80 -0
- wisent/cli/wisent_cli/ui.py +69 -0
- wisent/cli/wisent_cli/util/__init__.py +0 -0
- wisent/cli/wisent_cli/util/aggregations.py +43 -0
- wisent/cli/wisent_cli/util/parsing.py +126 -0
- wisent/cli/wisent_cli/version.py +4 -0
- wisent/core/__init__.py +27 -0
- wisent/core/activations/__init__.py +0 -0
- wisent/core/activations/activations_collector.py +338 -0
- wisent/core/activations/core/__init__.py +0 -0
- wisent/core/activations/core/atoms.py +216 -0
- wisent/core/agent/__init__.py +18 -0
- wisent/core/agent/budget.py +638 -0
- wisent/core/agent/device_benchmarks.py +685 -0
- wisent/core/agent/diagnose/__init__.py +55 -0
- wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
- wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
- wisent/core/agent/diagnose/create_classifier.py +1154 -0
- wisent/core/agent/diagnose/response_diagnostics.py +268 -0
- wisent/core/agent/diagnose/select_classifiers.py +506 -0
- wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
- wisent/core/agent/diagnose/tasks/__init__.py +33 -0
- wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
- wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
- wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
- wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
- wisent/core/agent/diagnose.py +242 -0
- wisent/core/agent/steer.py +212 -0
- wisent/core/agent/timeout.py +134 -0
- wisent/core/autonomous_agent.py +1234 -0
- wisent/core/bigcode_integration.py +583 -0
- wisent/core/contrastive_pairs/__init__.py +15 -0
- wisent/core/contrastive_pairs/core/__init__.py +0 -0
- wisent/core/contrastive_pairs/core/atoms.py +45 -0
- wisent/core/contrastive_pairs/core/buliders.py +59 -0
- wisent/core/contrastive_pairs/core/pair.py +178 -0
- wisent/core/contrastive_pairs/core/response.py +152 -0
- wisent/core/contrastive_pairs/core/serialization.py +300 -0
- wisent/core/contrastive_pairs/core/set.py +133 -0
- wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
- wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
- wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
- wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
- wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
- wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
- wisent/core/data_loaders/__init__.py +0 -0
- wisent/core/data_loaders/core/__init__.py +0 -0
- wisent/core/data_loaders/core/atoms.py +98 -0
- wisent/core/data_loaders/loaders/__init__.py +0 -0
- wisent/core/data_loaders/loaders/custom.py +120 -0
- wisent/core/data_loaders/loaders/lm_loader.py +218 -0
- wisent/core/detection_handling.py +257 -0
- wisent/core/download_full_benchmarks.py +1386 -0
- wisent/core/evaluators/__init__.py +0 -0
- wisent/core/evaluators/oracles/__init__.py +0 -0
- wisent/core/evaluators/oracles/interactive.py +73 -0
- wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
- wisent/core/evaluators/oracles/user_specified.py +67 -0
- wisent/core/hyperparameter_optimizer.py +429 -0
- wisent/core/lm_eval_harness_ground_truth.py +1396 -0
- wisent/core/log_likelihoods_evaluator.py +321 -0
- wisent/core/managed_cached_benchmarks.py +595 -0
- wisent/core/mixed_benchmark_sampler.py +364 -0
- wisent/core/model_config_manager.py +330 -0
- wisent/core/model_persistence.py +317 -0
- wisent/core/models/__init__.py +0 -0
- wisent/core/models/core/__init__.py +0 -0
- wisent/core/models/core/atoms.py +460 -0
- wisent/core/models/wisent_model.py +727 -0
- wisent/core/multi_steering.py +316 -0
- wisent/core/optuna/__init__.py +57 -0
- wisent/core/optuna/classifier/__init__.py +25 -0
- wisent/core/optuna/classifier/activation_generator.py +349 -0
- wisent/core/optuna/classifier/classifier_cache.py +509 -0
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
- wisent/core/optuna/steering/__init__.py +0 -0
- wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
- wisent/core/optuna/steering/data_utils.py +342 -0
- wisent/core/optuna/steering/metrics.py +474 -0
- wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
- wisent/core/optuna/steering/steering_optimization.py +1111 -0
- wisent/core/parser.py +1668 -0
- wisent/core/prompts/__init__.py +0 -0
- wisent/core/prompts/core/__init__.py +0 -0
- wisent/core/prompts/core/atom.py +57 -0
- wisent/core/prompts/core/prompt_formater.py +157 -0
- wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
- wisent/core/representation.py +5 -0
- wisent/core/sample_size_optimizer.py +648 -0
- wisent/core/sample_size_optimizer_v2.py +355 -0
- wisent/core/save_results.py +277 -0
- wisent/core/steering.py +652 -0
- wisent/core/steering_method.py +26 -0
- wisent/core/steering_methods/__init__.py +0 -0
- wisent/core/steering_methods/core/__init__.py +0 -0
- wisent/core/steering_methods/core/atoms.py +153 -0
- wisent/core/steering_methods/methods/__init__.py +0 -0
- wisent/core/steering_methods/methods/caa.py +44 -0
- wisent/core/steering_optimizer.py +1297 -0
- wisent/core/task_interface.py +132 -0
- wisent/core/task_selector.py +189 -0
- wisent/core/tasks/__init__.py +175 -0
- wisent/core/tasks/aime_task.py +141 -0
- wisent/core/tasks/file_task.py +211 -0
- wisent/core/tasks/hle_task.py +180 -0
- wisent/core/tasks/hmmt_task.py +119 -0
- wisent/core/tasks/livecodebench_task.py +201 -0
- wisent/core/tasks/livemathbench_task.py +158 -0
- wisent/core/tasks/lm_eval_task.py +455 -0
- wisent/core/tasks/math500_task.py +84 -0
- wisent/core/tasks/polymath_task.py +146 -0
- wisent/core/tasks/supergpqa_task.py +220 -0
- wisent/core/time_estimator.py +149 -0
- wisent/core/timing_calibration.py +174 -0
- wisent/core/tracking/__init__.py +54 -0
- wisent/core/tracking/latency.py +618 -0
- wisent/core/tracking/memory.py +359 -0
- wisent/core/trainers/__init__.py +0 -0
- wisent/core/trainers/core/__init__.py +11 -0
- wisent/core/trainers/core/atoms.py +45 -0
- wisent/core/trainers/steering_trainer.py +271 -0
- wisent/core/user_model_config.py +158 -0
- wisent/opti/__init__.py +0 -0
- wisent/opti/core/__init__.py +0 -0
- wisent/opti/core/atoms.py +175 -0
- wisent/opti/methods/__init__.py +0 -0
- wisent/opti/methods/opti_classificator.py +172 -0
- wisent/opti/methods/opti_steering.py +138 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/core/atoms.py +58 -0
- wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
- wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
- wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
- wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/atoms.py +25 -0
- wisent/synthetic/db_instructions/mini_dp.py +37 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/core/atoms.py +73 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/core.py +68 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
- wisent/synthetic/generators/pairs_generator.py +179 -0
- wisent-0.5.1.dist-info/METADATA +67 -0
- wisent-0.5.1.dist-info/RECORD +218 -0
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
- wisent/activations/__init__.py +0 -9
- wisent/activations/client.py +0 -97
- wisent/activations/extractor.py +0 -251
- wisent/activations/models.py +0 -95
- wisent/client.py +0 -45
- wisent/control_vector/__init__.py +0 -9
- wisent/control_vector/client.py +0 -85
- wisent/control_vector/manager.py +0 -168
- wisent/control_vector/models.py +0 -70
- wisent/inference/__init__.py +0 -9
- wisent/inference/client.py +0 -103
- wisent/inference/inferencer.py +0 -250
- wisent/inference/models.py +0 -66
- wisent/utils/__init__.py +0 -3
- wisent/utils/auth.py +0 -30
- wisent/utils/http.py +0 -228
- wisent/version.py +0 -3
- wisent-0.1.1.dist-info/METADATA +0 -142
- wisent-0.1.1.dist-info/RECORD +0 -23
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
wisent/core/data_loaders/loaders/custom.py
@@ -0,0 +1,120 @@
+from __future__ import annotations
+from typing import Any, Iterable
+import logging
+
+from wisent_guard.core.data_loaders.core.atoms import BaseDataLoader, DataLoaderError, LoadDataResult
+from wisent_guard.core.contrastive_pairs.core.set import ContrastivePairSet
+from wisent_guard.core.contrastive_pairs.core.serialization import load_contrastive_pair_set
+
+__all__ = [
+    "CustomUserDataLoader",
+]
+log = logging.getLogger(__name__)
+
+class CustomUserDataLoader(BaseDataLoader):
+    """
+    Load a ContrastivePairSet from a JSONL file, split into train/test,
+    and optionally cap each split.
+
+    attributes:
+        name: "custom"
+            The unique name of this data loader.
+        description: "Load contrastive pairs from custom JSONL and split."
+            A brief description of this data loader.
+    """
+    name = "custom"
+    description = "Load contrastive pairs from custom JSONL and split."
+
+    @staticmethod
+    def _shuffle_indices(n: int, seed: int | None) -> list[int]:
+        """
+        Generate a shuffled list of indices from 0 to n-1.
+
+        arguments:
+            n: The number of indices to generate.
+            seed: Optional random seed for reproducibility.
+
+        returns:
+            A list of shuffled indices.
+        """
+        idx = list(range(n))
+        if seed is None:
+            return idx
+        try:
+            from numpy.random import default_rng
+        except Exception:
+            import random
+            rnd = random.Random(seed)
+            rnd.shuffle(idx)
+            return idx
+        else:
+            return default_rng(seed).permutation(n).tolist()
+
+    def load(
+        self,
+        path: str,
+        split_ratio: float | None = None,
+        seed: int | None = None,
+        training_limit: int | None = None,
+        testing_limit: int | None = None,
+        **_: Any,
+    ) -> LoadDataResult:
+        """
+        Load contrastive pairs from a JSONL file, split into train/test sets,
+        and optionally limit the number of pairs in each set.
+
+        arguments:
+            path:
+                Path to the JSONL file containing contrastive pairs.
+            split_ratio:
+                Float in [0.0, 1.0] representing the proportion of data to use for training.
+                Defaults to 0.8 if None.
+            seed:
+                Optional random seed for shuffling the data before splitting.
+            training_limit:
+                Optional maximum number of training pairs to return.
+            testing_limit:
+                Optional maximum number of testing pairs to return.
+            **_:
+                Additional keyword arguments (ignored).
+        returns:
+            LoadDataResult with train/test ContrastivePairSets and metadata.
+
+        raises:
+            DataLoaderError if loading or processing fails.
+        """
+
+        if not path:
+            raise DataLoaderError("'path' is required for custom loader.")
+
+        split = self._effective_split(split_ratio)
+        data: ContrastivePairSet = load_contrastive_pair_set(path)
+        log.info("Loaded custom data: %r", data)
+
+        if not data.pairs:
+            raise DataLoaderError("No contrastive pairs found in the input file.")
+
+        n = len(data.pairs)
+        idx = self._shuffle_indices(n, seed)
+        split_at = int(n * split)
+
+        train_pairs = [data.pairs[i] for i in idx[:split_at]]
+        test_pairs = [data.pairs[i] for i in idx[split_at:]]
+
+        if training_limit is not None:
+            train_pairs = train_pairs[: max(0, int(training_limit))]
+        if testing_limit is not None:
+            test_pairs = test_pairs[: max(0, int(testing_limit))]
+
+        train_set = ContrastivePairSet(name=f"{data.name}_train", pairs=train_pairs, task_type=data.task_type)
+        test_set = ContrastivePairSet(name=f"{data.name}_test", pairs=test_pairs, task_type=data.task_type)
+
+        train_set.validate()
+        test_set.validate()
+
+        return LoadDataResult(
+            train_qa_pairs=train_set,
+            test_qa_pairs=test_set,
+            task_type=data.task_type or "custom",
+            lm_task_data=None,
+        )
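For orientation, a minimal usage sketch of the loader added above. It is not part of the diff: the import path is an assumption based on the module layout shown here, and the JSONL filename, seed, and limits are hypothetical; the result field names follow the LoadDataResult signature in the code.

# Hypothetical usage sketch for CustomUserDataLoader (not part of the diff).
from wisent_guard.core.data_loaders.loaders.custom import CustomUserDataLoader  # assumed import path

loader = CustomUserDataLoader()
result = loader.load(
    path="pairs.jsonl",    # hypothetical JSONL file of contrastive pairs
    split_ratio=0.8,       # proportion used for training (also the default when None)
    seed=42,               # reproducible shuffle before splitting
    training_limit=1000,   # optional cap on the train split
    testing_limit=200,     # optional cap on the test split
)
print(len(result.train_qa_pairs.pairs), len(result.test_qa_pairs.pairs))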
wisent/core/data_loaders/loaders/lm_loader.py
@@ -0,0 +1,218 @@
+from __future__ import annotations
+from typing import Any, TYPE_CHECKING
+import logging
+
+from wisent_guard.core.data_loaders.core.atoms import BaseDataLoader, DataLoaderError, LoadDataResult
+from wisent_guard.core.contrastive_pairs.core.pair import ContrastivePair
+from wisent_guard.core.contrastive_pairs.core.set import ContrastivePairSet
+from lm_eval.tasks import get_task_dict
+from lm_eval.tasks import TaskManager as LMTaskManager
+from wisent_guard.core.contrastive_pairs.lm_eval_pairs.lm_task_pairs_generation import (
+    lm_build_contrastive_pairs,
+)
+
+if TYPE_CHECKING:
+    from lm_eval.api.task import ConfigurableTask
+
+__all__ = [
+    "LMEvalDataLoader",
+]
+
+log = logging.getLogger(__name__)
+
+class LMEvalDataLoader(BaseDataLoader):
+    """
+    Load contrastive pairs from a single lm-evaluation-harness task via `load_lm_eval_task`,
+    split into train/test, and return a canonical LoadDataResult.
+    """
+    name = "lm_eval"
+    description = "Load from a single lm-eval task."
+
+    def _load_one_task(
+        self,
+        task_name: str,
+        split_ratio: float,
+        seed: int,
+        limit: int | None,
+        training_limit: int | None,
+        testing_limit: int | None,
+    ) -> LoadDataResult:
+        """
+        Load a single lm-eval task by name, convert to contrastive pairs,
+        split into train/test, and return a LoadDataResult.
+
+        arguments:
+            task_name: The name of the lm-eval task to load.
+            split_ratio: The fraction of data to use for training (between 0 and 1).
+            seed: Random seed for shuffling/splitting.
+            limit: Optional limit on total number of pairs to load.
+            training_limit: Optional limit on number of training pairs.
+            testing_limit: Optional limit on number of testing pairs.
+
+        returns:
+            A LoadDataResult containing train/test pairs and task info.
+
+        raises:
+            DataLoaderError if the task cannot be found or if splits are empty.
+            ValueError if split_ratio is not in [0.0, 1.0].
+            NotImplementedError if load_lm_eval_task is not implemented.
+
+        note:
+            This loader only supports single tasks, not mixtures. To load mixtures,
+            use a custom data loader or extend this one."""
+        loaded = self.load_lm_eval_task(task_name)
+
+        if isinstance(loaded, dict):
+            if len(loaded) != 1:
+                keys = ", ".join(sorted(loaded.keys()))
+                raise DataLoaderError(
+                    f"Task '{task_name}' returned {len(loaded)} subtasks ({keys}). "
+                    "Specify an explicit subtask, e.g. 'benchmark/subtask'."
+                )
+            (subname, task_obj), = loaded.items()
+            pairs_task_name = subname
+        else:
+            task_obj = loaded
+            pairs_task_name = task_name
+
+        pairs = lm_build_contrastive_pairs(
+            task_name=pairs_task_name,
+            lm_eval_task=task_obj,
+            limit=limit,
+        )
+
+        train_pairs, test_pairs = self._split_pairs(
+            pairs, split_ratio, seed, training_limit, testing_limit
+        )
+
+        if not train_pairs or not test_pairs:
+            raise DataLoaderError("One of the splits is empty after splitting.")
+
+        train_set = ContrastivePairSet("lm_eval_train", train_pairs, task_type=task_name)
+        test_set = ContrastivePairSet("lm_eval_test", test_pairs, task_type=task_name)
+
+        train_set.validate()
+        test_set.validate()
+
+        return LoadDataResult(
+            train_qa_pairs=train_set,
+            test_qa_pairs=test_set,
+            task_type=task_name,
+            lm_task_data=task_obj,
+        )
+
+    def load(
+        self,
+        task: str,
+        split_ratio: float | None = None,
+        seed: int = 42,
+        limit: int | None = None,
+        training_limit: int | None = None,
+        testing_limit: int | None = None,
+        **_: Any,
+    ) -> LoadDataResult:
+        """
+        Load contrastive pairs from a single lm-eval-harness task, split into train/test sets.
+        arguments:
+            task:
+                The name of the lm-eval task to load (e.g., "winogrande", "hellaswag").
+                Must be a single task, not a mixture.
+            split_ratio:
+                Float in [0.0, 1.0] representing the proportion of data to use for training.
+                Defaults to 0.8 if None.
+            seed:
+                Random seed for shuffling the data before splitting.
+            limit:
+                Optional maximum number of total pairs to load from the task.
+            training_limit:
+                Optional maximum number of training pairs to return.
+            testing_limit:
+                Optional maximum number of testing pairs to return.
+            **_:
+                Additional keyword arguments (ignored).
+
+        returns:
+            LoadDataResult with train/test ContrastivePairSets and metadata.
+
+        raises:
+            DataLoaderError if loading or processing fails.
+            ValueError if split_ratio is not in [0.0, 1.0].
+            NotImplementedError if load_lm_eval_task is not implemented.
+        """
+        split = self._effective_split(split_ratio)
+
+        # Single-task path only
+        return self._load_one_task(
+            task_name=str(task),
+            split_ratio=split,
+            seed=seed,
+            limit=limit,
+            training_limit=training_limit,
+            testing_limit=testing_limit,
+        )
+
+    @staticmethod
+    def load_lm_eval_task(task_name: str) -> ConfigurableTask | dict[str, ConfigurableTask]:
+        """
+        Load a single lm-eval-harness task by name.
+
+        arguments:
+            task_name: The name of the lm-eval task to load.
+
+        returns:
+            A ConfigurableTask instance or a dict of subtask name to ConfigurableTask.
+
+        raises:
+            DataLoaderError if the task cannot be found.
+        """
+        task_manager = LMTaskManager()
+        task_manager.initialize_tasks()
+
+        task_dict = get_task_dict([task_name], task_manager=task_manager)
+        if task_name in task_dict:
+            return task_dict[task_name]
+        raise DataLoaderError(f"lm-eval task '{task_name}' not found.")
+
+    def _split_pairs(
+        self,
+        pairs: list[ContrastivePair],
+        split_ratio: float,
+        seed: int,
+        training_limit: int | None,
+        testing_limit: int | None,
+    ) -> tuple[list[ContrastivePair], list[ContrastivePair]]:
+        """
+        Split a list of ContrastivePairs into train/test sets.
+
+        arguments:
+            pairs: List of ContrastivePair to split.
+            split_ratio: Float in [0.0, 1.0] for the training set proportion.
+            seed: Random seed for shuffling.
+            training_limit: Optional max number of training pairs.
+            testing_limit: Optional max number of testing pairs.
+
+        returns:
+            A tuple of (train_pairs, test_pairs).
+        raises:
+            ValueError if split_ratio is not in [0.0, 1.0].
+        """
+        if not pairs:
+            return [], []
+        from numpy.random import default_rng
+
+        idx = list(range(len(pairs)))
+        default_rng(seed).shuffle(idx)
+        cut = int(len(pairs) * split_ratio)
+        train_idx = set(idx[:cut])
+
+        train_pairs: list[ContrastivePair] = []
+        test_pairs: list[ContrastivePair] = []
+        for i in idx:
+            (train_pairs if i in train_idx else test_pairs).append(pairs[i])
+
+        if training_limit and training_limit > 0:
+            train_pairs = train_pairs[:training_limit]
+        if testing_limit and testing_limit > 0:
+            test_pairs = test_pairs[:testing_limit]
+
+        return train_pairs, test_pairs
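Similarly, a hedged usage sketch for the lm-eval loader added above. It assumes lm-evaluation-harness is installed and that the import path mirrors the module layout in this diff; the task name, seed, and limit are illustrative only.

# Hypothetical usage sketch for LMEvalDataLoader (not part of the diff).
from wisent_guard.core.data_loaders.loaders.lm_loader import LMEvalDataLoader  # assumed import path

loader = LMEvalDataLoader()
result = loader.load(
    task="winogrande",  # a single lm-eval task; mixtures raise DataLoaderError
    split_ratio=0.8,
    seed=42,
    limit=500,          # cap on total pairs built from the task
)
# result.lm_task_data keeps the underlying ConfigurableTask for later evaluation
print(len(result.train_qa_pairs.pairs), len(result.test_qa_pairs.pairs))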
wisent/core/detection_handling.py
@@ -0,0 +1,257 @@
+"""
+Detection handling module for wisent-guard.
+
+This module provides different strategies for handling responses that have been
+detected as problematic (hallucinations, harmful content, bias, etc.).
+"""
+
+from enum import Enum
+from typing import Optional, Callable, Dict, Any
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class DetectionAction(Enum):
+    """Actions to take when problematic content is detected."""
+    PASS_THROUGH = "pass_through"  # Output the response as-is
+    REPLACE_WITH_PLACEHOLDER = "replace_with_placeholder"  # Replace with safe message
+    REGENERATE_UNTIL_SAFE = "regenerate_until_safe"  # Keep regenerating until safe
+
+
+class DetectionHandler:
+    """
+    Handles responses when problematic content is detected.
+
+    Provides configurable strategies for dealing with detected issues like
+    hallucinations, harmful content, bias, etc.
+    """
+
+    def __init__(
+        self,
+        action: DetectionAction = DetectionAction.REPLACE_WITH_PLACEHOLDER,
+        placeholder_message: Optional[str] = None,
+        max_regeneration_attempts: int = 3,
+        custom_placeholder_generator: Optional[Callable[[str, str], str]] = None,
+        log_detections: bool = True
+    ):
+        """
+        Initialize the detection handler.
+
+        Args:
+            action: What action to take when detection occurs
+            placeholder_message: Custom placeholder message (if None, uses default)
+            max_regeneration_attempts: Maximum times to regenerate before giving up
+            custom_placeholder_generator: Function to generate custom placeholders
+            log_detections: Whether to log detection events
+        """
+        self.action = action
+        self.placeholder_message = placeholder_message
+        self.max_regeneration_attempts = max_regeneration_attempts
+        self.custom_placeholder_generator = custom_placeholder_generator
+        self.log_detections = log_detections
+
+        # Default placeholder messages for different detection types
+        self.default_placeholders = {
+            "hallucination": "I apologize, but I may not have accurate information about this topic. Please verify any factual claims from reliable sources.",
+            "harmful_content": "I cannot provide information that could be harmful or dangerous. Please ask about something else I can help with safely.",
+            "bias": "I want to avoid potentially biased responses. Let me try to provide a more balanced perspective on this topic.",
+            "personal_info": "I cannot generate or discuss personal information. Please ask about general topics instead.",
+            "scheming": "I cannot provide advice on deceptive or manipulative behavior. Let me help you with ethical approaches instead.",
+            "bad_code": "I cannot provide code examples that may contain security vulnerabilities. Let me suggest secure coding practices instead.",
+            "default": "I apologize, but I cannot provide an appropriate response to this request. Please try rephrasing your question."
+        }
+
+    def handle_detection(
+        self,
+        original_response: str,
+        detection_type: str,
+        confidence_score: float,
+        original_prompt: str,
+        regenerate_function: Optional[Callable[[], str]] = None
+    ) -> str:
+        """
+        Handle a detected problematic response based on the configured action.
+
+        Args:
+            original_response: The response that was flagged
+            detection_type: Type of issue detected (e.g., "hallucination", "bias")
+            confidence_score: Confidence score of the detection (0.0 to 1.0)
+            original_prompt: The original prompt that generated the response
+            regenerate_function: Function to call for regeneration (if needed)
+
+        Returns:
+            The final response to return to the user
+        """
+        if self.log_detections:
+            logger.warning(
+                f"Detected {detection_type} with confidence {confidence_score:.3f} "
+                f"in response: {original_response[:100]}..."
+            )
+
+        if self.action == DetectionAction.PASS_THROUGH:
+            return self._handle_pass_through(original_response, detection_type, confidence_score)
+
+        elif self.action == DetectionAction.REPLACE_WITH_PLACEHOLDER:
+            return self._handle_replacement(original_response, detection_type, original_prompt)
+
+        elif self.action == DetectionAction.REGENERATE_UNTIL_SAFE:
+            return self._handle_regeneration(
+                original_response, detection_type, original_prompt, regenerate_function
+            )
+
+        else:
+            raise ValueError(f"Unknown detection action: {self.action}")
+
+    def _handle_pass_through(
+        self,
+        original_response: str,
+        detection_type: str,
+        confidence_score: float
+    ) -> str:
+        """Handle pass-through action - return response as-is with optional warning."""
+        if self.log_detections:
+            logger.info(f"Passing through response despite {detection_type} detection")
+
+        # Optionally add a warning prefix (can be configured)
+        return original_response
+
+    def _handle_replacement(
+        self,
+        original_response: str,
+        detection_type: str,
+        original_prompt: str
+    ) -> str:
+        """Handle replacement action - return placeholder message."""
+        if self.custom_placeholder_generator:
+            return self.custom_placeholder_generator(detection_type, original_prompt)
+
+        if self.placeholder_message:
+            return self.placeholder_message
+
+        # Use default placeholder for the detection type
+        return self.default_placeholders.get(detection_type, self.default_placeholders["default"])
+
+    def _handle_regeneration(
+        self,
+        original_response: str,
+        detection_type: str,
+        original_prompt: str,
+        regenerate_function: Optional[Callable[[], str]]
+    ) -> str:
+        """Handle regeneration action - keep generating until safe response."""
+        if not regenerate_function:
+            logger.warning("No regeneration function provided, falling back to placeholder")
+            return self._handle_replacement(original_response, detection_type, original_prompt)
+
+        attempts = 0
+        current_response = original_response
+
+        while attempts < self.max_regeneration_attempts:
+            attempts += 1
+
+            if self.log_detections:
+                logger.info(f"Regeneration attempt {attempts}/{self.max_regeneration_attempts}")
+
+            try:
+                # Generate a new response
+                new_response = regenerate_function()
+
+                # Note: In a real implementation, you would re-run the detection here
+                # For now, we'll assume the regeneration function handles this
+                return new_response
+
+            except Exception as e:
+                logger.error(f"Error during regeneration attempt {attempts}: {e}")
+                continue
+
+        # If we've exhausted attempts, fall back to placeholder
+        if self.log_detections:
+            logger.warning(
+                f"Failed to generate safe response after {self.max_regeneration_attempts} attempts, "
+                f"using placeholder"
+            )
+
+        return self._handle_replacement(original_response, detection_type, original_prompt)
+
+    def set_custom_placeholder(self, detection_type: str, message: str):
+        """Set a custom placeholder message for a specific detection type."""
+        self.default_placeholders[detection_type] = message
+
+    def get_detection_stats(self) -> Dict[str, Any]:
+        """Get statistics about detection handling (placeholder for future implementation)."""
+        return {
+            "action": self.action.value,
+            "max_regeneration_attempts": self.max_regeneration_attempts,
+            "available_placeholders": list(self.default_placeholders.keys())
+        }
+
+
+# Convenience functions for common use cases
+
+def create_pass_through_handler() -> DetectionHandler:
+    """Create a handler that passes through all responses unchanged."""
+    return DetectionHandler(action=DetectionAction.PASS_THROUGH)
+
+
+def create_placeholder_handler(custom_message: Optional[str] = None) -> DetectionHandler:
+    """Create a handler that replaces detected responses with placeholders."""
+    return DetectionHandler(
+        action=DetectionAction.REPLACE_WITH_PLACEHOLDER,
+        placeholder_message=custom_message
+    )
+
+
+def create_regeneration_handler(max_attempts: int = 3) -> DetectionHandler:
+    """Create a handler that regenerates responses until they're safe."""
+    return DetectionHandler(
+        action=DetectionAction.REGENERATE_UNTIL_SAFE,
+        max_regeneration_attempts=max_attempts
+    )
+
+
+def create_custom_handler(
+    placeholder_generator: Callable[[str, str], str],
+    action: DetectionAction = DetectionAction.REPLACE_WITH_PLACEHOLDER
+) -> DetectionHandler:
+    """Create a handler with a custom placeholder generator function."""
+    return DetectionHandler(
+        action=action,
+        custom_placeholder_generator=placeholder_generator
+    )
+
+
+# Example custom placeholder generators
+
+def educational_placeholder_generator(detection_type: str, original_prompt: str) -> str:
+    """Generate educational placeholders that explain why content was flagged."""
+    explanations = {
+        "hallucination": f"The response to '{original_prompt}' may contain inaccurate information. "
+                         "Please verify facts from reliable sources before relying on this information.",
+        "harmful_content": f"I cannot provide a response to '{original_prompt}' as it may involve "
+                           "harmful or dangerous content. Please ask about safer topics.",
+        "bias": f"The response to '{original_prompt}' might contain biased perspectives. "
+                "Consider seeking multiple viewpoints on this topic.",
+        "personal_info": f"I cannot respond to '{original_prompt}' as it involves personal information. "
+                         "Please ask about general topics instead."
+    }
+
+    return explanations.get(
+        detection_type,
+        f"I cannot provide an appropriate response to '{original_prompt}'. "
+        "Please try rephrasing your question."
+    )
+
+
+def brief_placeholder_generator(detection_type: str, original_prompt: str) -> str:
+    """Generate brief, minimal placeholder messages."""
+    brief_messages = {
+        "hallucination": "Information may be inaccurate.",
+        "harmful_content": "Cannot provide harmful content.",
+        "bias": "Response may be biased.",
+        "personal_info": "Cannot share personal information.",
+        "scheming": "Cannot provide deceptive advice.",
+        "bad_code": "Cannot provide insecure code."
+    }
+
+    return brief_messages.get(detection_type, "Cannot provide response.")