wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wisent might be problematic. Click here for more details.

Files changed (237) hide show
  1. wisent/__init__.py +1 -8
  2. wisent/benchmarks/__init__.py +0 -0
  3. wisent/benchmarks/coding/__init__.py +0 -0
  4. wisent/benchmarks/coding/metrics/__init__.py +0 -0
  5. wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
  6. wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
  7. wisent/benchmarks/coding/metrics/evaluator.py +275 -0
  8. wisent/benchmarks/coding/metrics/passk.py +66 -0
  9. wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
  10. wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
  11. wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
  12. wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
  13. wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
  14. wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
  15. wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
  16. wisent/benchmarks/coding/providers/__init__.py +18 -0
  17. wisent/benchmarks/coding/providers/core/__init__.py +0 -0
  18. wisent/benchmarks/coding/providers/core/atoms.py +31 -0
  19. wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
  20. wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
  21. wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
  22. wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
  23. wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
  24. wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
  25. wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
  26. wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
  27. wisent/classifiers/__init__.py +0 -0
  28. wisent/classifiers/core/__init__.py +0 -0
  29. wisent/classifiers/core/atoms.py +747 -0
  30. wisent/classifiers/models/__init__.py +0 -0
  31. wisent/classifiers/models/logistic.py +29 -0
  32. wisent/classifiers/models/mlp.py +47 -0
  33. wisent/cli/__init__.py +0 -0
  34. wisent/cli/classifiers/__init__.py +0 -0
  35. wisent/cli/classifiers/classifier_rotator.py +137 -0
  36. wisent/cli/cli_logger.py +142 -0
  37. wisent/cli/data_loaders/__init__.py +0 -0
  38. wisent/cli/data_loaders/data_loader_rotator.py +96 -0
  39. wisent/cli/evaluators/__init__.py +0 -0
  40. wisent/cli/evaluators/evaluator_rotator.py +148 -0
  41. wisent/cli/steering_methods/__init__.py +0 -0
  42. wisent/cli/steering_methods/steering_rotator.py +110 -0
  43. wisent/cli/wisent_cli/__init__.py +0 -0
  44. wisent/cli/wisent_cli/commands/__init__.py +0 -0
  45. wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
  46. wisent/cli/wisent_cli/commands/listing.py +154 -0
  47. wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
  48. wisent/cli/wisent_cli/main.py +93 -0
  49. wisent/cli/wisent_cli/shell.py +80 -0
  50. wisent/cli/wisent_cli/ui.py +69 -0
  51. wisent/cli/wisent_cli/util/__init__.py +0 -0
  52. wisent/cli/wisent_cli/util/aggregations.py +43 -0
  53. wisent/cli/wisent_cli/util/parsing.py +126 -0
  54. wisent/cli/wisent_cli/version.py +4 -0
  55. wisent/core/__init__.py +27 -0
  56. wisent/core/activations/__init__.py +0 -0
  57. wisent/core/activations/activations_collector.py +338 -0
  58. wisent/core/activations/core/__init__.py +0 -0
  59. wisent/core/activations/core/atoms.py +216 -0
  60. wisent/core/agent/__init__.py +18 -0
  61. wisent/core/agent/budget.py +638 -0
  62. wisent/core/agent/device_benchmarks.py +685 -0
  63. wisent/core/agent/diagnose/__init__.py +55 -0
  64. wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
  65. wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
  66. wisent/core/agent/diagnose/create_classifier.py +1154 -0
  67. wisent/core/agent/diagnose/response_diagnostics.py +268 -0
  68. wisent/core/agent/diagnose/select_classifiers.py +506 -0
  69. wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
  70. wisent/core/agent/diagnose/tasks/__init__.py +33 -0
  71. wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
  72. wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
  73. wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
  74. wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
  75. wisent/core/agent/diagnose.py +242 -0
  76. wisent/core/agent/steer.py +212 -0
  77. wisent/core/agent/timeout.py +134 -0
  78. wisent/core/autonomous_agent.py +1234 -0
  79. wisent/core/bigcode_integration.py +583 -0
  80. wisent/core/contrastive_pairs/__init__.py +15 -0
  81. wisent/core/contrastive_pairs/core/__init__.py +0 -0
  82. wisent/core/contrastive_pairs/core/atoms.py +45 -0
  83. wisent/core/contrastive_pairs/core/buliders.py +59 -0
  84. wisent/core/contrastive_pairs/core/pair.py +178 -0
  85. wisent/core/contrastive_pairs/core/response.py +152 -0
  86. wisent/core/contrastive_pairs/core/serialization.py +300 -0
  87. wisent/core/contrastive_pairs/core/set.py +133 -0
  88. wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
  89. wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
  90. wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
  91. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
  92. wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
  93. wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
  94. wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
  95. wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
  96. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
  102. wisent/core/data_loaders/__init__.py +0 -0
  103. wisent/core/data_loaders/core/__init__.py +0 -0
  104. wisent/core/data_loaders/core/atoms.py +98 -0
  105. wisent/core/data_loaders/loaders/__init__.py +0 -0
  106. wisent/core/data_loaders/loaders/custom.py +120 -0
  107. wisent/core/data_loaders/loaders/lm_loader.py +218 -0
  108. wisent/core/detection_handling.py +257 -0
  109. wisent/core/download_full_benchmarks.py +1386 -0
  110. wisent/core/evaluators/__init__.py +0 -0
  111. wisent/core/evaluators/oracles/__init__.py +0 -0
  112. wisent/core/evaluators/oracles/interactive.py +73 -0
  113. wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
  114. wisent/core/evaluators/oracles/user_specified.py +67 -0
  115. wisent/core/hyperparameter_optimizer.py +429 -0
  116. wisent/core/lm_eval_harness_ground_truth.py +1396 -0
  117. wisent/core/log_likelihoods_evaluator.py +321 -0
  118. wisent/core/managed_cached_benchmarks.py +595 -0
  119. wisent/core/mixed_benchmark_sampler.py +364 -0
  120. wisent/core/model_config_manager.py +330 -0
  121. wisent/core/model_persistence.py +317 -0
  122. wisent/core/models/__init__.py +0 -0
  123. wisent/core/models/core/__init__.py +0 -0
  124. wisent/core/models/core/atoms.py +460 -0
  125. wisent/core/models/wisent_model.py +727 -0
  126. wisent/core/multi_steering.py +316 -0
  127. wisent/core/optuna/__init__.py +57 -0
  128. wisent/core/optuna/classifier/__init__.py +25 -0
  129. wisent/core/optuna/classifier/activation_generator.py +349 -0
  130. wisent/core/optuna/classifier/classifier_cache.py +509 -0
  131. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
  132. wisent/core/optuna/steering/__init__.py +0 -0
  133. wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
  134. wisent/core/optuna/steering/data_utils.py +342 -0
  135. wisent/core/optuna/steering/metrics.py +474 -0
  136. wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
  137. wisent/core/optuna/steering/steering_optimization.py +1111 -0
  138. wisent/core/parser.py +1668 -0
  139. wisent/core/prompts/__init__.py +0 -0
  140. wisent/core/prompts/core/__init__.py +0 -0
  141. wisent/core/prompts/core/atom.py +57 -0
  142. wisent/core/prompts/core/prompt_formater.py +157 -0
  143. wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
  144. wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
  145. wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
  146. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
  147. wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
  148. wisent/core/representation.py +5 -0
  149. wisent/core/sample_size_optimizer.py +648 -0
  150. wisent/core/sample_size_optimizer_v2.py +355 -0
  151. wisent/core/save_results.py +277 -0
  152. wisent/core/steering.py +652 -0
  153. wisent/core/steering_method.py +26 -0
  154. wisent/core/steering_methods/__init__.py +0 -0
  155. wisent/core/steering_methods/core/__init__.py +0 -0
  156. wisent/core/steering_methods/core/atoms.py +153 -0
  157. wisent/core/steering_methods/methods/__init__.py +0 -0
  158. wisent/core/steering_methods/methods/caa.py +44 -0
  159. wisent/core/steering_optimizer.py +1297 -0
  160. wisent/core/task_interface.py +132 -0
  161. wisent/core/task_selector.py +189 -0
  162. wisent/core/tasks/__init__.py +175 -0
  163. wisent/core/tasks/aime_task.py +141 -0
  164. wisent/core/tasks/file_task.py +211 -0
  165. wisent/core/tasks/hle_task.py +180 -0
  166. wisent/core/tasks/hmmt_task.py +119 -0
  167. wisent/core/tasks/livecodebench_task.py +201 -0
  168. wisent/core/tasks/livemathbench_task.py +158 -0
  169. wisent/core/tasks/lm_eval_task.py +455 -0
  170. wisent/core/tasks/math500_task.py +84 -0
  171. wisent/core/tasks/polymath_task.py +146 -0
  172. wisent/core/tasks/supergpqa_task.py +220 -0
  173. wisent/core/time_estimator.py +149 -0
  174. wisent/core/timing_calibration.py +174 -0
  175. wisent/core/tracking/__init__.py +54 -0
  176. wisent/core/tracking/latency.py +618 -0
  177. wisent/core/tracking/memory.py +359 -0
  178. wisent/core/trainers/__init__.py +0 -0
  179. wisent/core/trainers/core/__init__.py +11 -0
  180. wisent/core/trainers/core/atoms.py +45 -0
  181. wisent/core/trainers/steering_trainer.py +271 -0
  182. wisent/core/user_model_config.py +158 -0
  183. wisent/opti/__init__.py +0 -0
  184. wisent/opti/core/__init__.py +0 -0
  185. wisent/opti/core/atoms.py +175 -0
  186. wisent/opti/methods/__init__.py +0 -0
  187. wisent/opti/methods/opti_classificator.py +172 -0
  188. wisent/opti/methods/opti_steering.py +138 -0
  189. wisent/synthetic/__init__.py +0 -0
  190. wisent/synthetic/cleaners/__init__.py +0 -0
  191. wisent/synthetic/cleaners/core/__init__.py +0 -0
  192. wisent/synthetic/cleaners/core/atoms.py +58 -0
  193. wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
  194. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  195. wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
  196. wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
  197. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  198. wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
  199. wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
  200. wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
  201. wisent/synthetic/db_instructions/__init__.py +0 -0
  202. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  203. wisent/synthetic/db_instructions/core/atoms.py +25 -0
  204. wisent/synthetic/db_instructions/mini_dp.py +37 -0
  205. wisent/synthetic/generators/__init__.py +0 -0
  206. wisent/synthetic/generators/core/__init__.py +0 -0
  207. wisent/synthetic/generators/core/atoms.py +73 -0
  208. wisent/synthetic/generators/diversities/__init__.py +0 -0
  209. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  210. wisent/synthetic/generators/diversities/core/core.py +68 -0
  211. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  212. wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
  213. wisent/synthetic/generators/pairs_generator.py +179 -0
  214. wisent-0.5.1.dist-info/METADATA +67 -0
  215. wisent-0.5.1.dist-info/RECORD +218 -0
  216. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
  217. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
  218. wisent/activations/__init__.py +0 -9
  219. wisent/activations/client.py +0 -97
  220. wisent/activations/extractor.py +0 -251
  221. wisent/activations/models.py +0 -95
  222. wisent/client.py +0 -45
  223. wisent/control_vector/__init__.py +0 -9
  224. wisent/control_vector/client.py +0 -85
  225. wisent/control_vector/manager.py +0 -168
  226. wisent/control_vector/models.py +0 -70
  227. wisent/inference/__init__.py +0 -9
  228. wisent/inference/client.py +0 -103
  229. wisent/inference/inferencer.py +0 -250
  230. wisent/inference/models.py +0 -66
  231. wisent/utils/__init__.py +0 -3
  232. wisent/utils/auth.py +0 -30
  233. wisent/utils/http.py +0 -228
  234. wisent/version.py +0 -3
  235. wisent-0.1.1.dist-info/METADATA +0 -142
  236. wisent-0.1.1.dist-info/RECORD +0 -23
  237. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,120 @@
1
+ from __future__ import annotations
2
+ from typing import Any, Iterable
3
+ import logging
4
+
5
+ from wisent_guard.core.data_loaders.core.atoms import BaseDataLoader, DataLoaderError, LoadDataResult
6
+ from wisent_guard.core.contrastive_pairs.core.set import ContrastivePairSet
7
+ from wisent_guard.core.contrastive_pairs.core.serialization import load_contrastive_pair_set
8
+
9
+ __all__ = [
10
+ "CustomUserDataLoader",
11
+ ]
12
+ log = logging.getLogger(__name__)
13
+
14
class CustomUserDataLoader(BaseDataLoader):
    """
    Data loader for user-supplied contrastive pairs stored in a JSONL file.

    The file is read into a ContrastivePairSet, partitioned into a training
    and a testing subset, and each subset may be truncated to a maximum size.

    attributes:
        name: "custom"
            Identifier of this data loader.
        description: "Load contrastive pairs from custom JSONL and split."
            Short human-readable summary of this data loader.
    """
    name = "custom"
    description = "Load contrastive pairs from custom JSONL and split."

    @staticmethod
    def _shuffle_indices(n: int, seed: int | None) -> list[int]:
        """
        Produce the index order used to partition ``n`` items.

        arguments:
            n: How many indices (0 .. n-1) to produce.
            seed: Seed for the permutation. When None, nothing is shuffled
                and the indices are returned in ascending order.

        returns:
            A list containing each index in 0 .. n-1 exactly once.
        """
        if seed is None:
            # No seed given: keep the original order.
            return list(range(n))

        try:
            from numpy.random import default_rng
        except Exception:
            # numpy could not be imported: shuffle with the standard library.
            import random

            order = list(range(n))
            random.Random(seed).shuffle(order)
            return order

        return default_rng(seed).permutation(n).tolist()

    def load(
        self,
        path: str,
        split_ratio: float | None = None,
        seed: int | None = None,
        training_limit: int | None = None,
        testing_limit: int | None = None,
        **_: Any,
    ) -> LoadDataResult:
        """
        Read contrastive pairs from a JSONL file and return them as a
        train/test pair of ContrastivePairSets.

        arguments:
            path:
                Location of the JSONL file with the contrastive pairs.
            split_ratio:
                Proportion of the data assigned to the training set; resolved
                through ``_effective_split`` (which also handles None).
            seed:
                Seed used to shuffle before splitting. With None the pairs
                are split in their original order.
            training_limit:
                If given, keep at most this many training pairs.
            testing_limit:
                If given, keep at most this many testing pairs.
            **_:
                Extra keyword arguments; accepted and ignored.

        returns:
            LoadDataResult carrying the train and test ContrastivePairSets,
            the task type ("custom" when the file does not define one) and
            no lm-eval task data.

        raises:
            DataLoaderError if ``path`` is empty or the file yields no pairs.
        """
        if not path:
            raise DataLoaderError("'path' is required for custom loader.")

        train_fraction = self._effective_split(split_ratio)
        loaded: ContrastivePairSet = load_contrastive_pair_set(path)
        log.info("Loaded custom data: %r", loaded)

        if not loaded.pairs:
            raise DataLoaderError("No contrastive pairs found in the input file.")

        total = len(loaded.pairs)
        order = self._shuffle_indices(total, seed)
        boundary = int(total * train_fraction)

        # Everything before the boundary trains, everything after it tests.
        train_pairs = [loaded.pairs[position] for position in order[:boundary]]
        test_pairs = [loaded.pairs[position] for position in order[boundary:]]

        # Negative limits are clamped to zero, i.e. they empty the split.
        if training_limit is not None:
            train_cap = max(0, int(training_limit))
            train_pairs = train_pairs[:train_cap]
        if testing_limit is not None:
            test_cap = max(0, int(testing_limit))
            test_pairs = test_pairs[:test_cap]

        train_set = ContrastivePairSet(
            name=f"{loaded.name}_train",
            pairs=train_pairs,
            task_type=loaded.task_type,
        )
        test_set = ContrastivePairSet(
            name=f"{loaded.name}_test",
            pairs=test_pairs,
            task_type=loaded.task_type,
        )

        train_set.validate()
        test_set.validate()

        return LoadDataResult(
            train_qa_pairs=train_set,
            test_qa_pairs=test_set,
            task_type=loaded.task_type or "custom",
            lm_task_data=None,
        )
@@ -0,0 +1,218 @@
1
+ from __future__ import annotations
2
+ from typing import Any, TYPE_CHECKING
3
+ import logging
4
+
5
+ from wisent_guard.core.data_loaders.core.atoms import BaseDataLoader, DataLoaderError, LoadDataResult
6
+ from wisent_guard.core.contrastive_pairs.core.pair import ContrastivePair
7
+ from wisent_guard.core.contrastive_pairs.core.set import ContrastivePairSet
8
+ from lm_eval.tasks import get_task_dict
9
+ from lm_eval.tasks import TaskManager as LMTaskManager
10
+ from wisent_guard.core.contrastive_pairs.lm_eval_pairs.lm_task_pairs_generation import (
11
+ lm_build_contrastive_pairs,
12
+ )
13
+
14
+ if TYPE_CHECKING:
15
+ from lm_eval.api.task import ConfigurableTask
16
+
17
+ __all__ = [
18
+ "LMEvalDataLoader",
19
+ ]
20
+
21
+ log = logging.getLogger(__name__)
22
+
23
class LMEvalDataLoader(BaseDataLoader):
    """
    Load contrastive pairs from a single lm-evaluation-harness task via `load_lm_eval_task`,
    split into train/test, and return a canonical LoadDataResult.
    """
    name = "lm_eval"
    description = "Load from a single lm-eval task."

    def _load_one_task(
        self,
        task_name: str,
        split_ratio: float,
        seed: int,
        limit: int | None,
        training_limit: int | None,
        testing_limit: int | None,
    ) -> LoadDataResult:
        """
        Load a single lm-eval task by name, convert to contrastive pairs,
        split into train/test, and return a LoadDataResult.

        arguments:
            task_name: The name of the lm-eval task to load.
            split_ratio: The fraction of data to use for training (between 0 and 1).
            seed: Random seed for shuffling/splitting.
            limit: Optional limit on total number of pairs to load.
            training_limit: Optional limit on number of training pairs.
            testing_limit: Optional limit on number of testing pairs.

        returns:
            A LoadDataResult containing train/test pairs and task info.

        raises:
            DataLoaderError if the task resolves to anything other than exactly
                one subtask, or if either split is empty after splitting.
            ValueError if split_ratio is not in [0.0, 1.0].

        note:
            This loader only supports single tasks, not mixtures. To load mixtures,
            use a custom data loader or extend this one.
        """
        loaded = self.load_lm_eval_task(task_name)

        if isinstance(loaded, dict):
            # A dict means the name resolved to named subtasks; only a
            # single-entry dict is unambiguous.
            if len(loaded) != 1:
                keys = ", ".join(sorted(loaded.keys()))
                raise DataLoaderError(
                    f"Task '{task_name}' returned {len(loaded)} subtasks ({keys}). "
                    "Specify an explicit subtask, e.g. 'benchmark/subtask'."
                )
            (subname, task_obj), = loaded.items()
            pairs_task_name = subname
        else:
            task_obj = loaded
            pairs_task_name = task_name

        pairs = lm_build_contrastive_pairs(
            task_name=pairs_task_name,
            lm_eval_task=task_obj,
            limit=limit,
        )

        train_pairs, test_pairs = self._split_pairs(
            pairs, split_ratio, seed, training_limit, testing_limit
        )

        if not train_pairs or not test_pairs:
            raise DataLoaderError("One of the splits is empty after splitting.")

        train_set = ContrastivePairSet("lm_eval_train", train_pairs, task_type=task_name)
        test_set = ContrastivePairSet("lm_eval_test", test_pairs, task_type=task_name)

        train_set.validate()
        test_set.validate()

        return LoadDataResult(
            train_qa_pairs=train_set,
            test_qa_pairs=test_set,
            task_type=task_name,
            lm_task_data=task_obj,
        )

    def load(
        self,
        task: str,
        split_ratio: float | None = None,
        seed: int = 42,
        limit: int | None = None,
        training_limit: int | None = None,
        testing_limit: int | None = None,
        **_: Any,
    ) -> LoadDataResult:
        """
        Load contrastive pairs from a single lm-eval-harness task, split into train/test sets.

        arguments:
            task:
                The name of the lm-eval task to load (e.g., "winogrande", "hellaswag").
                Must be a single task, not a mixture.
            split_ratio:
                Float in [0.0, 1.0] representing the proportion of data to use for training.
                Resolved through `_effective_split`, which also handles None.
            seed:
                Random seed for shuffling the data before splitting.
            limit:
                Optional maximum number of total pairs to load from the task.
            training_limit:
                Optional maximum number of training pairs to return.
            testing_limit:
                Optional maximum number of testing pairs to return.
            **_:
                Additional keyword arguments (ignored).

        returns:
            LoadDataResult with train/test ContrastivePairSets and metadata.

        raises:
            DataLoaderError if loading or processing fails.
            ValueError if split_ratio is not in [0.0, 1.0].
        """
        split = self._effective_split(split_ratio)

        # Single-task path only
        return self._load_one_task(
            task_name=str(task),
            split_ratio=split,
            seed=seed,
            limit=limit,
            training_limit=training_limit,
            testing_limit=testing_limit,
        )

    @staticmethod
    def load_lm_eval_task(task_name: str) -> ConfigurableTask | dict[str, ConfigurableTask]:
        """
        Load a single lm-eval-harness task by name.

        arguments:
            task_name: The name of the lm-eval task to load.

        returns:
            A ConfigurableTask instance or a dict of subtask name to ConfigurableTask.

        raises:
            DataLoaderError if the task name is missing from the resolved task dict.
        """
        # NOTE(review): get_task_dict may itself raise for an unknown name
        # before the membership check below is reached - confirm against the
        # installed lm-eval version.
        task_manager = LMTaskManager()
        task_manager.initialize_tasks()

        task_dict = get_task_dict([task_name], task_manager=task_manager)
        if task_name in task_dict:
            return task_dict[task_name]
        raise DataLoaderError(f"lm-eval task '{task_name}' not found.")

    def _split_pairs(
        self,
        pairs: list[ContrastivePair],
        split_ratio: float,
        seed: int,
        training_limit: int | None,
        testing_limit: int | None,
    ) -> tuple[list[ContrastivePair], list[ContrastivePair]]:
        """
        Split a list of ContrastivePairs into train/test sets.

        arguments:
            pairs: List of ContrastivePair to split.
            split_ratio: Float in [0.0, 1.0] for the training set proportion.
            seed: Random seed for shuffling.
            training_limit: Optional max number of training pairs. None, 0 and
                negative values all mean "no limit".
            testing_limit: Optional max number of testing pairs. None, 0 and
                negative values all mean "no limit".

        returns:
            A tuple of (train_pairs, test_pairs), both in shuffled order.

        raises:
            ValueError if split_ratio is not in [0.0, 1.0].
        """
        # Written as a negated range test so that NaN is rejected as well.
        if not (0.0 <= split_ratio <= 1.0):
            raise ValueError(f"split_ratio must be in [0.0, 1.0], got {split_ratio!r}")

        if not pairs:
            return [], []

        from numpy.random import default_rng

        idx = list(range(len(pairs)))
        default_rng(seed).shuffle(idx)
        cut = int(len(pairs) * split_ratio)

        # The first `cut` shuffled indices train, the remainder test.
        train_pairs: list[ContrastivePair] = [pairs[i] for i in idx[:cut]]
        test_pairs: list[ContrastivePair] = [pairs[i] for i in idx[cut:]]

        if training_limit and training_limit > 0:
            train_pairs = train_pairs[:training_limit]
        if testing_limit and testing_limit > 0:
            test_pairs = test_pairs[:testing_limit]

        return train_pairs, test_pairs
@@ -0,0 +1,257 @@
1
+ """
2
+ Detection handling module for wisent-guard.
3
+
4
+ This module provides different strategies for handling responses that have been
5
+ detected as problematic (hallucinations, harmful content, bias, etc.).
6
+ """
7
+
8
+ from enum import Enum
9
+ from typing import Optional, Callable, Dict, Any
10
+ import logging
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class DetectionAction(Enum):
    """Strategies available once a response has been flagged as problematic."""

    # Return the flagged response unchanged.
    PASS_THROUGH = "pass_through"

    # Substitute the flagged response with a safe placeholder message.
    REPLACE_WITH_PLACEHOLDER = "replace_with_placeholder"

    # Ask for new responses until one is considered safe.
    REGENERATE_UNTIL_SAFE = "regenerate_until_safe"
20
+
21
+
22
class DetectionHandler:
    """
    Handles responses when problematic content is detected.

    Provides configurable strategies for dealing with detected issues like
    hallucinations, harmful content, bias, etc.
    """

    def __init__(
        self,
        action: DetectionAction = DetectionAction.REPLACE_WITH_PLACEHOLDER,
        placeholder_message: Optional[str] = None,
        max_regeneration_attempts: int = 3,
        custom_placeholder_generator: Optional[Callable[[str, str], str]] = None,
        log_detections: bool = True
    ):
        """
        Initialize the detection handler.

        Args:
            action: What action to take when detection occurs
            placeholder_message: Custom placeholder message (if None, uses default)
            max_regeneration_attempts: Maximum times to regenerate before giving up
            custom_placeholder_generator: Function called as
                ``generator(detection_type, original_prompt)`` to build a placeholder
            log_detections: Whether to log detection events
        """
        self.action = action
        self.placeholder_message = placeholder_message
        self.max_regeneration_attempts = max_regeneration_attempts
        self.custom_placeholder_generator = custom_placeholder_generator
        self.log_detections = log_detections

        # Default placeholder messages for different detection types.
        # Built per instance so set_custom_placeholder() only affects this handler.
        self.default_placeholders = {
            "hallucination": "I apologize, but I may not have accurate information about this topic. Please verify any factual claims from reliable sources.",
            "harmful_content": "I cannot provide information that could be harmful or dangerous. Please ask about something else I can help with safely.",
            "bias": "I want to avoid potentially biased responses. Let me try to provide a more balanced perspective on this topic.",
            "personal_info": "I cannot generate or discuss personal information. Please ask about general topics instead.",
            "scheming": "I cannot provide advice on deceptive or manipulative behavior. Let me help you with ethical approaches instead.",
            "bad_code": "I cannot provide code examples that may contain security vulnerabilities. Let me suggest secure coding practices instead.",
            "default": "I apologize, but I cannot provide an appropriate response to this request. Please try rephrasing your question."
        }

    def handle_detection(
        self,
        original_response: str,
        detection_type: str,
        confidence_score: float,
        original_prompt: str,
        regenerate_function: Optional[Callable[[], str]] = None
    ) -> str:
        """
        Handle a detected problematic response based on the configured action.

        Args:
            original_response: The response that was flagged
            detection_type: Type of issue detected (e.g., "hallucination", "bias")
            confidence_score: Confidence score of the detection (0.0 to 1.0)
            original_prompt: The original prompt that generated the response
            regenerate_function: Function to call for regeneration (if needed)

        Returns:
            The final response to return to the user

        Raises:
            ValueError: If the configured action is not a known DetectionAction
        """
        if self.log_detections:
            logger.warning(
                f"Detected {detection_type} with confidence {confidence_score:.3f} "
                f"in response: {original_response[:100]}..."
            )

        if self.action == DetectionAction.PASS_THROUGH:
            return self._handle_pass_through(original_response, detection_type, confidence_score)

        elif self.action == DetectionAction.REPLACE_WITH_PLACEHOLDER:
            return self._handle_replacement(original_response, detection_type, original_prompt)

        elif self.action == DetectionAction.REGENERATE_UNTIL_SAFE:
            return self._handle_regeneration(
                original_response, detection_type, original_prompt, regenerate_function
            )

        else:
            raise ValueError(f"Unknown detection action: {self.action}")

    def _handle_pass_through(
        self,
        original_response: str,
        detection_type: str,
        confidence_score: float
    ) -> str:
        """Handle pass-through action - return the response unchanged, logging if enabled."""
        if self.log_detections:
            logger.info(f"Passing through response despite {detection_type} detection")

        # Optionally add a warning prefix (can be configured)
        return original_response

    def _handle_replacement(
        self,
        original_response: str,
        detection_type: str,
        original_prompt: str
    ) -> str:
        """
        Handle replacement action - return placeholder message.

        Precedence: the custom generator, then the fixed ``placeholder_message``,
        then the per-type default (falling back to the "default" entry).
        """
        if self.custom_placeholder_generator:
            return self.custom_placeholder_generator(detection_type, original_prompt)

        if self.placeholder_message:
            return self.placeholder_message

        # Use default placeholder for the detection type
        return self.default_placeholders.get(detection_type, self.default_placeholders["default"])

    def _handle_regeneration(
        self,
        original_response: str,
        detection_type: str,
        original_prompt: str,
        regenerate_function: Optional[Callable[[], str]]
    ) -> str:
        """
        Handle regeneration action - ask for a fresh response.

        The first regeneration call that completes without raising is returned
        as-is; detection is NOT re-run here, so the regeneration function is
        expected to take care of safety itself. If no function is supplied, or
        every attempt raises, the placeholder path is used instead.
        """
        if not regenerate_function:
            logger.warning("No regeneration function provided, falling back to placeholder")
            return self._handle_replacement(original_response, detection_type, original_prompt)

        for attempt in range(1, self.max_regeneration_attempts + 1):
            if self.log_detections:
                logger.info(f"Regeneration attempt {attempt}/{self.max_regeneration_attempts}")

            try:
                # Note: In a real implementation, you would re-run the detection here
                # For now, we'll assume the regeneration function handles this
                return regenerate_function()
            except Exception as exc:
                # Deliberately broad: any failure of the user-supplied callable
                # just consumes one attempt. logger.exception keeps the traceback.
                logger.exception("Error during regeneration attempt %s: %s", attempt, exc)

        # If we've exhausted attempts, fall back to placeholder
        if self.log_detections:
            logger.warning(
                f"Failed to generate safe response after {self.max_regeneration_attempts} attempts, "
                f"using placeholder"
            )

        return self._handle_replacement(original_response, detection_type, original_prompt)

    def set_custom_placeholder(self, detection_type: str, message: str):
        """Set a custom placeholder message for a specific detection type."""
        self.default_placeholders[detection_type] = message

    def get_detection_stats(self) -> Dict[str, Any]:
        """
        Describe the handler's configuration (placeholder for future statistics).

        Returns:
            A dict with the action value, the regeneration attempt cap and the
            detection types that currently have a placeholder message.
        """
        return {
            "action": self.action.value,
            "max_regeneration_attempts": self.max_regeneration_attempts,
            "available_placeholders": list(self.default_placeholders.keys())
        }
188
+
189
+
190
+ # Convenience functions for common use cases
191
+
192
def create_pass_through_handler() -> DetectionHandler:
    """Build a DetectionHandler whose action is PASS_THROUGH (responses are returned as-is)."""
    handler = DetectionHandler(action=DetectionAction.PASS_THROUGH)
    return handler
195
+
196
+
197
def create_placeholder_handler(custom_message: Optional[str] = None) -> DetectionHandler:
    """
    Build a DetectionHandler that swaps flagged responses for a placeholder.

    Args:
        custom_message: Fixed placeholder text; when None the handler's
            per-detection-type defaults are used.
    """
    handler = DetectionHandler(
        action=DetectionAction.REPLACE_WITH_PLACEHOLDER,
        placeholder_message=custom_message,
    )
    return handler
203
+
204
+
205
def create_regeneration_handler(max_attempts: int = 3) -> DetectionHandler:
    """
    Build a DetectionHandler that regenerates flagged responses.

    Args:
        max_attempts: Upper bound on regeneration attempts before the
            handler falls back to a placeholder.
    """
    handler = DetectionHandler(
        action=DetectionAction.REGENERATE_UNTIL_SAFE,
        max_regeneration_attempts=max_attempts,
    )
    return handler
211
+
212
+
213
def create_custom_handler(
    placeholder_generator: Callable[[str, str], str],
    action: DetectionAction = DetectionAction.REPLACE_WITH_PLACEHOLDER
) -> DetectionHandler:
    """
    Build a DetectionHandler that uses a caller-supplied placeholder generator.

    Args:
        placeholder_generator: Callable invoked with (detection_type,
            original_prompt) that returns the placeholder text.
        action: Action the handler should take on detection.
    """
    handler = DetectionHandler(
        action=action,
        custom_placeholder_generator=placeholder_generator,
    )
    return handler
222
+
223
+
224
+ # Example custom placeholder generators
225
+
226
def educational_placeholder_generator(detection_type: str, original_prompt: str) -> str:
    """
    Return a placeholder that explains why the response was flagged.

    Args:
        detection_type: Kind of issue detected; "hallucination",
            "harmful_content", "bias" and "personal_info" have dedicated
            explanations, anything else gets a generic message.
        original_prompt: The prompt, quoted inside the returned message.
    """
    by_type = {
        "hallucination": f"The response to '{original_prompt}' may contain inaccurate information. "
                         "Please verify facts from reliable sources before relying on this information.",
        "harmful_content": f"I cannot provide a response to '{original_prompt}' as it may involve "
                           "harmful or dangerous content. Please ask about safer topics.",
        "bias": f"The response to '{original_prompt}' might contain biased perspectives. "
                "Consider seeking multiple viewpoints on this topic.",
        "personal_info": f"I cannot respond to '{original_prompt}' as it involves personal information. "
                         "Please ask about general topics instead.",
    }
    generic = (
        f"I cannot provide an appropriate response to '{original_prompt}'. "
        "Please try rephrasing your question."
    )

    try:
        return by_type[detection_type]
    except KeyError:
        # Unrecognised detection type: use the generic explanation.
        return generic
244
+
245
+
246
def brief_placeholder_generator(detection_type: str, original_prompt: str) -> str:
    """
    Return a one-sentence placeholder for the given detection type.

    Args:
        detection_type: Kind of issue detected; unknown types get a generic
            message.
        original_prompt: Accepted for signature compatibility with other
            placeholder generators; not used.
    """
    short_texts = {
        "hallucination": "Information may be inaccurate.",
        "harmful_content": "Cannot provide harmful content.",
        "bias": "Response may be biased.",
        "personal_info": "Cannot share personal information.",
        "scheming": "Cannot provide deceptive advice.",
        "bad_code": "Cannot provide insecure code.",
    }

    try:
        return short_texts[detection_type]
    except KeyError:
        return "Cannot provide response."