wisent 0.1.1__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wisent might be problematic. Click here for more details.

Files changed (237) hide show
  1. wisent/__init__.py +1 -8
  2. wisent/benchmarks/__init__.py +0 -0
  3. wisent/benchmarks/coding/__init__.py +0 -0
  4. wisent/benchmarks/coding/metrics/__init__.py +0 -0
  5. wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
  6. wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
  7. wisent/benchmarks/coding/metrics/evaluator.py +275 -0
  8. wisent/benchmarks/coding/metrics/passk.py +66 -0
  9. wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
  10. wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
  11. wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
  12. wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
  13. wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
  14. wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
  15. wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
  16. wisent/benchmarks/coding/providers/__init__.py +18 -0
  17. wisent/benchmarks/coding/providers/core/__init__.py +0 -0
  18. wisent/benchmarks/coding/providers/core/atoms.py +31 -0
  19. wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
  20. wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
  21. wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
  22. wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
  23. wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
  24. wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
  25. wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
  26. wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
  27. wisent/classifiers/__init__.py +0 -0
  28. wisent/classifiers/core/__init__.py +0 -0
  29. wisent/classifiers/core/atoms.py +747 -0
  30. wisent/classifiers/models/__init__.py +0 -0
  31. wisent/classifiers/models/logistic.py +29 -0
  32. wisent/classifiers/models/mlp.py +47 -0
  33. wisent/cli/__init__.py +0 -0
  34. wisent/cli/classifiers/__init__.py +0 -0
  35. wisent/cli/classifiers/classifier_rotator.py +137 -0
  36. wisent/cli/cli_logger.py +142 -0
  37. wisent/cli/data_loaders/__init__.py +0 -0
  38. wisent/cli/data_loaders/data_loader_rotator.py +96 -0
  39. wisent/cli/evaluators/__init__.py +0 -0
  40. wisent/cli/evaluators/evaluator_rotator.py +148 -0
  41. wisent/cli/steering_methods/__init__.py +0 -0
  42. wisent/cli/steering_methods/steering_rotator.py +110 -0
  43. wisent/cli/wisent_cli/__init__.py +0 -0
  44. wisent/cli/wisent_cli/commands/__init__.py +0 -0
  45. wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
  46. wisent/cli/wisent_cli/commands/listing.py +154 -0
  47. wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
  48. wisent/cli/wisent_cli/main.py +93 -0
  49. wisent/cli/wisent_cli/shell.py +80 -0
  50. wisent/cli/wisent_cli/ui.py +69 -0
  51. wisent/cli/wisent_cli/util/__init__.py +0 -0
  52. wisent/cli/wisent_cli/util/aggregations.py +43 -0
  53. wisent/cli/wisent_cli/util/parsing.py +126 -0
  54. wisent/cli/wisent_cli/version.py +4 -0
  55. wisent/core/__init__.py +27 -0
  56. wisent/core/activations/__init__.py +0 -0
  57. wisent/core/activations/activations_collector.py +338 -0
  58. wisent/core/activations/core/__init__.py +0 -0
  59. wisent/core/activations/core/atoms.py +216 -0
  60. wisent/core/agent/__init__.py +18 -0
  61. wisent/core/agent/budget.py +638 -0
  62. wisent/core/agent/device_benchmarks.py +685 -0
  63. wisent/core/agent/diagnose/__init__.py +55 -0
  64. wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
  65. wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
  66. wisent/core/agent/diagnose/create_classifier.py +1154 -0
  67. wisent/core/agent/diagnose/response_diagnostics.py +268 -0
  68. wisent/core/agent/diagnose/select_classifiers.py +506 -0
  69. wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
  70. wisent/core/agent/diagnose/tasks/__init__.py +33 -0
  71. wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
  72. wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
  73. wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
  74. wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
  75. wisent/core/agent/diagnose.py +242 -0
  76. wisent/core/agent/steer.py +212 -0
  77. wisent/core/agent/timeout.py +134 -0
  78. wisent/core/autonomous_agent.py +1234 -0
  79. wisent/core/bigcode_integration.py +583 -0
  80. wisent/core/contrastive_pairs/__init__.py +15 -0
  81. wisent/core/contrastive_pairs/core/__init__.py +0 -0
  82. wisent/core/contrastive_pairs/core/atoms.py +45 -0
  83. wisent/core/contrastive_pairs/core/buliders.py +59 -0
  84. wisent/core/contrastive_pairs/core/pair.py +178 -0
  85. wisent/core/contrastive_pairs/core/response.py +152 -0
  86. wisent/core/contrastive_pairs/core/serialization.py +300 -0
  87. wisent/core/contrastive_pairs/core/set.py +133 -0
  88. wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
  89. wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
  90. wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
  91. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
  92. wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
  93. wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
  94. wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
  95. wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
  96. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
  102. wisent/core/data_loaders/__init__.py +0 -0
  103. wisent/core/data_loaders/core/__init__.py +0 -0
  104. wisent/core/data_loaders/core/atoms.py +98 -0
  105. wisent/core/data_loaders/loaders/__init__.py +0 -0
  106. wisent/core/data_loaders/loaders/custom.py +120 -0
  107. wisent/core/data_loaders/loaders/lm_loader.py +218 -0
  108. wisent/core/detection_handling.py +257 -0
  109. wisent/core/download_full_benchmarks.py +1386 -0
  110. wisent/core/evaluators/__init__.py +0 -0
  111. wisent/core/evaluators/oracles/__init__.py +0 -0
  112. wisent/core/evaluators/oracles/interactive.py +73 -0
  113. wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
  114. wisent/core/evaluators/oracles/user_specified.py +67 -0
  115. wisent/core/hyperparameter_optimizer.py +429 -0
  116. wisent/core/lm_eval_harness_ground_truth.py +1396 -0
  117. wisent/core/log_likelihoods_evaluator.py +321 -0
  118. wisent/core/managed_cached_benchmarks.py +595 -0
  119. wisent/core/mixed_benchmark_sampler.py +364 -0
  120. wisent/core/model_config_manager.py +330 -0
  121. wisent/core/model_persistence.py +317 -0
  122. wisent/core/models/__init__.py +0 -0
  123. wisent/core/models/core/__init__.py +0 -0
  124. wisent/core/models/core/atoms.py +460 -0
  125. wisent/core/models/wisent_model.py +727 -0
  126. wisent/core/multi_steering.py +316 -0
  127. wisent/core/optuna/__init__.py +57 -0
  128. wisent/core/optuna/classifier/__init__.py +25 -0
  129. wisent/core/optuna/classifier/activation_generator.py +349 -0
  130. wisent/core/optuna/classifier/classifier_cache.py +509 -0
  131. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
  132. wisent/core/optuna/steering/__init__.py +0 -0
  133. wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
  134. wisent/core/optuna/steering/data_utils.py +342 -0
  135. wisent/core/optuna/steering/metrics.py +474 -0
  136. wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
  137. wisent/core/optuna/steering/steering_optimization.py +1111 -0
  138. wisent/core/parser.py +1668 -0
  139. wisent/core/prompts/__init__.py +0 -0
  140. wisent/core/prompts/core/__init__.py +0 -0
  141. wisent/core/prompts/core/atom.py +57 -0
  142. wisent/core/prompts/core/prompt_formater.py +157 -0
  143. wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
  144. wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
  145. wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
  146. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
  147. wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
  148. wisent/core/representation.py +5 -0
  149. wisent/core/sample_size_optimizer.py +648 -0
  150. wisent/core/sample_size_optimizer_v2.py +355 -0
  151. wisent/core/save_results.py +277 -0
  152. wisent/core/steering.py +652 -0
  153. wisent/core/steering_method.py +26 -0
  154. wisent/core/steering_methods/__init__.py +0 -0
  155. wisent/core/steering_methods/core/__init__.py +0 -0
  156. wisent/core/steering_methods/core/atoms.py +153 -0
  157. wisent/core/steering_methods/methods/__init__.py +0 -0
  158. wisent/core/steering_methods/methods/caa.py +44 -0
  159. wisent/core/steering_optimizer.py +1297 -0
  160. wisent/core/task_interface.py +132 -0
  161. wisent/core/task_selector.py +189 -0
  162. wisent/core/tasks/__init__.py +175 -0
  163. wisent/core/tasks/aime_task.py +141 -0
  164. wisent/core/tasks/file_task.py +211 -0
  165. wisent/core/tasks/hle_task.py +180 -0
  166. wisent/core/tasks/hmmt_task.py +119 -0
  167. wisent/core/tasks/livecodebench_task.py +201 -0
  168. wisent/core/tasks/livemathbench_task.py +158 -0
  169. wisent/core/tasks/lm_eval_task.py +455 -0
  170. wisent/core/tasks/math500_task.py +84 -0
  171. wisent/core/tasks/polymath_task.py +146 -0
  172. wisent/core/tasks/supergpqa_task.py +220 -0
  173. wisent/core/time_estimator.py +149 -0
  174. wisent/core/timing_calibration.py +174 -0
  175. wisent/core/tracking/__init__.py +54 -0
  176. wisent/core/tracking/latency.py +618 -0
  177. wisent/core/tracking/memory.py +359 -0
  178. wisent/core/trainers/__init__.py +0 -0
  179. wisent/core/trainers/core/__init__.py +11 -0
  180. wisent/core/trainers/core/atoms.py +45 -0
  181. wisent/core/trainers/steering_trainer.py +271 -0
  182. wisent/core/user_model_config.py +158 -0
  183. wisent/opti/__init__.py +0 -0
  184. wisent/opti/core/__init__.py +0 -0
  185. wisent/opti/core/atoms.py +175 -0
  186. wisent/opti/methods/__init__.py +0 -0
  187. wisent/opti/methods/opti_classificator.py +172 -0
  188. wisent/opti/methods/opti_steering.py +138 -0
  189. wisent/synthetic/__init__.py +0 -0
  190. wisent/synthetic/cleaners/__init__.py +0 -0
  191. wisent/synthetic/cleaners/core/__init__.py +0 -0
  192. wisent/synthetic/cleaners/core/atoms.py +58 -0
  193. wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
  194. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  195. wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
  196. wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
  197. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  198. wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
  199. wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
  200. wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
  201. wisent/synthetic/db_instructions/__init__.py +0 -0
  202. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  203. wisent/synthetic/db_instructions/core/atoms.py +25 -0
  204. wisent/synthetic/db_instructions/mini_dp.py +37 -0
  205. wisent/synthetic/generators/__init__.py +0 -0
  206. wisent/synthetic/generators/core/__init__.py +0 -0
  207. wisent/synthetic/generators/core/atoms.py +73 -0
  208. wisent/synthetic/generators/diversities/__init__.py +0 -0
  209. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  210. wisent/synthetic/generators/diversities/core/core.py +68 -0
  211. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  212. wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
  213. wisent/synthetic/generators/pairs_generator.py +179 -0
  214. wisent-0.5.2.dist-info/METADATA +67 -0
  215. wisent-0.5.2.dist-info/RECORD +218 -0
  216. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/WHEEL +1 -1
  217. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info/licenses}/LICENSE +2 -2
  218. wisent/activations/__init__.py +0 -9
  219. wisent/activations/client.py +0 -97
  220. wisent/activations/extractor.py +0 -251
  221. wisent/activations/models.py +0 -95
  222. wisent/client.py +0 -45
  223. wisent/control_vector/__init__.py +0 -9
  224. wisent/control_vector/client.py +0 -85
  225. wisent/control_vector/manager.py +0 -168
  226. wisent/control_vector/models.py +0 -70
  227. wisent/inference/__init__.py +0 -9
  228. wisent/inference/client.py +0 -103
  229. wisent/inference/inferencer.py +0 -250
  230. wisent/inference/models.py +0 -66
  231. wisent/utils/__init__.py +0 -3
  232. wisent/utils/auth.py +0 -30
  233. wisent/utils/http.py +0 -228
  234. wisent/version.py +0 -3
  235. wisent-0.1.1.dist-info/METADATA +0 -142
  236. wisent-0.1.1.dist-info/RECORD +0 -23
  237. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/top_level.txt +0 -0
File without changes
@@ -0,0 +1,29 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from wisent.classifiers.core.atoms import BaseClassifier
7
+
8
+ __all__ = ["LogisticClassifier"]
9
+
10
class LogisticModel(nn.Module):
    """Simple logistic regression model for activation classification.

    A single linear layer maps ``input_dim`` features to one logit, which is
    passed through a sigmoid.

    Args:
        input_dim: Size of the last dimension of the input tensor.
    """

    def __init__(self, input_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return sigmoid outputs for ``x``.

        When the linear layer yields a 1-D tensor (``x`` was a single
        unbatched feature vector) a trailing axis is added so the result is
        never 1-D.
        """
        logits = self.linear(x)
        if logits.ndim == 1:
            logits = logits.unsqueeze(1)
        return self.sigmoid(logits)
22
+
23
+
24
class LogisticClassifier(BaseClassifier):
    """Classifier named ``"logistic"`` whose model is a :class:`LogisticModel`."""

    name = "logistic"
    description = "One-layer logistic regression over dense features"

    def build_model(self, input_dim: int, **_: object) -> nn.Module:
        """Build the model for ``input_dim`` features.

        Any extra keyword arguments are accepted and ignored.
        """
        return LogisticModel(input_dim)
@@ -0,0 +1,47 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from wisent.classifiers.core.atoms import BaseClassifier
7
+
8
+ __all__ = ["MLPClassifier"]
9
+
10
class MLPModel(nn.Module):
    """Multi-layer perceptron for activation classification.

    Layout: ``input_dim -> hidden_dim -> hidden_dim // 2 -> 1``, with ReLU and
    dropout after each hidden layer and a final sigmoid.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        hidden_dim: Width of the first hidden layer; the second is half of it
            (integer division).
        dropout: Dropout probability used after both hidden layers.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 128, dropout: float = 0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return sigmoid outputs for ``x``.

        A 1-D result (``x`` was a single unbatched feature vector) gets a
        trailing axis so the output is never 1-D.
        """
        out = self.net(x)
        if out.ndim == 1:
            out = out.unsqueeze(1)
        return out
30
+
31
+
32
class MLPClassifier(BaseClassifier):
    """Classifier named ``"mlp"`` whose model is an :class:`MLPModel`."""

    name = "mlp"
    description = "Two-layer MLP with dropout and ReLU"

    # Dropout used by build_model when the caller does not pass one.
    _DEFAULT_DROPOUT = 0.2

    def __init__(self, *, hidden_dim: int = 128, **base_kwargs):
        """Store the default hidden width; other kwargs go to the base class."""
        super().__init__(**base_kwargs)
        self._hidden_dim = int(hidden_dim)
        self._dropout = self._DEFAULT_DROPOUT

    def build_model(self, input_dim: int, **model_params: object) -> nn.Module:
        """Build an :class:`MLPModel`.

        Recognised ``model_params``:
            hidden_dim: overrides the width given at construction time.
            dropout: dropout probability (defaults to 0.2).

        The values actually used are remembered so that
        :meth:`model_hyperparams` describes the model that was built.
        """
        hd = int(model_params.get("hidden_dim", self._hidden_dim))
        dp = float(model_params.get("dropout", self._DEFAULT_DROPOUT))
        self._hidden_dim = hd
        # Previously the dropout was not recorded, so model_hyperparams()
        # always claimed 0.2 even for models built with another value.
        self._dropout = dp
        return MLPModel(input_dim, hidden_dim=hd, dropout=dp)

    def model_hyperparams(self) -> dict[str, int | float]:
        """Return the hyperparameters of the most recently built model."""
        return {
            "hidden_dim": self._hidden_dim,
            "dropout": getattr(self, "_dropout", self._DEFAULT_DROPOUT),
        }
wisent/cli/__init__.py ADDED
File without changes
File without changes
@@ -0,0 +1,137 @@
1
+ from __future__ import annotations
2
+
3
+ import importlib
4
+ import importlib.util
5
+ import inspect
6
+ import pkgutil
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ from wisent.core.classifiers.core.atoms import BaseClassifier, ClassifierError, ClassifierTrainReport
11
+
12
+ __all__ = ["ClassifierRotator"]
13
+
14
class ClassifierRotator:
    """
    Discover, list, and delegate to registered classifiers.

    The rotator keeps one active ``BaseClassifier`` instance and forwards
    ``fit`` / ``predict`` / ``evaluate`` / persistence calls to it.
    """

    def __init__(
        self,
        classifier: str | BaseClassifier | type[BaseClassifier] | None = None,
        classifiers_location: str | Path = "wisent_guard.core.classifiers.models",
        autoload: bool = True,
        **classifier_kwargs: Any,
    ) -> None:
        """
        Args:
            classifier: Registered name, instance, subclass, or ``None`` to
                pick the first registered classifier by name.
            classifiers_location: Dotted package path or directory to scan
                when ``autoload`` is true.
            autoload: Run :meth:`discover_classifiers` before resolving.
            **classifier_kwargs: Passed to the classifier constructor when the
                rotator has to instantiate one.
        """
        # NOTE(review): the default location refers to a ``wisent_guard``
        # package while this module imports from ``wisent`` - confirm the
        # default is importable in the target environment.
        if autoload:
            self.discover_classifiers(classifiers_location)
        self._classifier = self._resolve_classifier(classifier, **classifier_kwargs)

    @staticmethod
    def discover_classifiers(location: str | Path = "wisent_guard.core.classifiers.models") -> None:
        """
        Import all classifier modules so BaseClassifier subclasses self-register.

        - If `location` is an existing directory (Path/str), import all .py files inside.
        - Otherwise `location` must be a dotted module path: that package is
          imported and every non-underscore module found on its ``__path__``
          (or next to its ``__file__``) is imported too.

        Raises:
            ClassifierError: `location` is neither a directory nor a string,
                or the package cannot be imported.
        """
        loc_path = Path(str(location))
        if loc_path.exists() and loc_path.is_dir():
            ClassifierRotator._import_all_py_in_dir(loc_path)
            return

        if not isinstance(location, str):
            raise ClassifierError(
                f"Invalid classifiers location: {location!r}. Provide a dotted module path or a directory."
            )

        try:
            pkg = importlib.import_module(location)
        except ModuleNotFoundError as exc:
            raise ClassifierError(
                f"Cannot import classifier package {location!r}. "
                f"Use a dotted path (no leading slash) and ensure your project root is on PYTHONPATH."
            ) from exc

        search_paths = list(getattr(pkg, "__path__", []))
        if not search_paths:
            # Plain module rather than a package: fall back to its directory.
            pkg_file = getattr(pkg, "__file__", None)
            if pkg_file:
                search_paths = [str(Path(pkg_file).parent)]

        for _finder, name, _ispkg in pkgutil.iter_modules(search_paths):
            if name.startswith("_"):
                continue
            importlib.import_module(f"{location}.{name}")

    @staticmethod
    def _import_all_py_in_dir(directory: Path) -> None:
        """Execute every non-underscore ``*.py`` file in `directory` as a module.

        Modules are named ``_dyn_classifiers_<stem>`` and are registered in
        ``sys.modules`` before execution, as the importlib recipe for loading
        a source file requires; code that looks a class's module up by name
        would otherwise fail. Files are processed in sorted order so the
        registration order is reproducible.
        """
        import sys  # local import keeps this block self-contained

        for py in sorted(directory.glob("*.py")):
            if py.name.startswith("_"):
                continue
            mod_name = f"_dyn_classifiers_{py.stem}"
            spec = importlib.util.spec_from_file_location(mod_name, py)
            if not (spec and spec.loader):
                continue
            module = importlib.util.module_from_spec(spec)
            sys.modules[mod_name] = module
            try:
                spec.loader.exec_module(module)  # type: ignore[attr-defined]
            except BaseException:
                # Do not leave a half-initialised module importable.
                sys.modules.pop(mod_name, None)
                raise

    @staticmethod
    def list_classifiers() -> list[dict[str, Any]]:
        """Return name/description/class-path records, sorted by name."""
        out = [
            {
                "name": name,
                "description": getattr(cls, "description", ""),
                "class": f"{cls.__module__}.{cls.__name__}",
            }
            for name, cls in BaseClassifier.list_registered().items()
        ]
        return sorted(out, key=lambda x: x["name"])

    @staticmethod
    def _resolve_classifier(
        classifier: str | BaseClassifier | type[BaseClassifier] | None,
        **kwargs: Any,
    ) -> BaseClassifier:
        """Turn `classifier` into an instance.

        `kwargs` are used only when a new instance is constructed; they are
        ignored when an existing instance is passed in.

        Raises:
            ClassifierError: `classifier` is ``None`` and nothing is registered.
            TypeError: `classifier` is of an unsupported type.
        """
        if classifier is None:
            registry = BaseClassifier.list_registered()
            if not registry:
                raise ClassifierError("No classifiers registered.")
            # Deterministic pick: first by name
            return registry[min(registry)](**kwargs)
        if isinstance(classifier, BaseClassifier):
            return classifier
        if inspect.isclass(classifier) and issubclass(classifier, BaseClassifier):
            return classifier(**kwargs)
        if isinstance(classifier, str):
            cls = BaseClassifier.get(classifier)
            return cls(**kwargs)
        raise TypeError(
            "classifier must be None, a name (str), BaseClassifier instance, or BaseClassifier subclass."
        )

    def use(self, classifier: str | BaseClassifier | type[BaseClassifier], **kwargs: Any) -> None:
        """Replace the active classifier."""
        self._classifier = self._resolve_classifier(classifier, **kwargs)

    def fit(self, X, y, **kwargs) -> ClassifierTrainReport:
        """Delegate to the active classifier's ``fit``."""
        return self._classifier.fit(X, y, **kwargs)

    def predict(self, X):
        """Delegate to the active classifier's ``predict``."""
        return self._classifier.predict(X)

    def predict_proba(self, X):
        """Delegate to the active classifier's ``predict_proba``."""
        return self._classifier.predict_proba(X)

    def evaluate(self, X, y) -> dict[str, float]:
        """Delegate to the active classifier's ``evaluate``."""
        return self._classifier.evaluate(X, y)

    def save_model(self, path: str) -> None:
        """Delegate to the active classifier's ``save_model``."""
        self._classifier.save_model(path)

    def load_model(self, path: str) -> None:
        """Delegate to the active classifier's ``load_model``."""
        self._classifier.load_model(path)

    def set_threshold(self, threshold: float) -> None:
        """Delegate to the active classifier's ``set_threshold``."""
        self._classifier.set_threshold(threshold)
@@ -0,0 +1,142 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import logging
5
+ import sys
6
+ from datetime import datetime, timezone
7
+ from typing import Any, Mapping
8
+
9
+ __all__ = [
10
+ "setup_logger",
11
+ "bind",
12
+ "JsonFormatter",
13
+ "ContextAdapter",
14
+ "add_file_handler",
15
+ ]
16
+
17
class JsonFormatter(logging.Formatter):
    """
    Minimal JSON formatter with structured fields + extras.

    Each record becomes one JSON object with ``ts`` (UTC ISO-8601), ``level``,
    ``logger``, ``message``, ``file``, ``func`` and ``line``. Any non-standard,
    non-underscore attribute on the record is collected under ``extra``; the
    formatted traceback, if any, goes under ``exc``.
    """

    # Attributes the logging machinery itself places on a LogRecord; anything
    # else is treated as a user-supplied extra. "message"/"asctime" are set by
    # logging.Formatter.format() when another handler formatted the record
    # first, and "taskName" exists on Python 3.12+.
    _STD = frozenset({
        "name", "msg", "args", "levelname", "levelno", "pathname",
        "filename", "module", "exc_info", "exc_text", "stack_info",
        "lineno", "funcName", "created", "msecs", "relativeCreated",
        "thread", "threadName", "processName", "process",
        "message", "asctime", "taskName",
    })

    def format(self, record: logging.LogRecord) -> str:
        """Serialise `record` to a single-line JSON string."""
        payload: dict[str, Any] = {
            "ts": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "file": record.filename,
            "func": record.funcName,
            "line": record.lineno,
        }
        extras = {
            k: v for k, v in record.__dict__.items()
            if k not in self._STD and not k.startswith("_")
        }
        if extras:
            payload["extra"] = extras
        if record.exc_info:
            payload["exc"] = self.formatException(record.exc_info)
        # Extras are arbitrary objects; fall back to repr() rather than letting
        # a non-serialisable value make the whole log call fail.
        return json.dumps(payload, ensure_ascii=False, default=repr)
47
+
48
+
49
class ContextAdapter(logging.LoggerAdapter):
    """
    LoggerAdapter that ensures persistent context fields appear in every log entry.
    """

    def process(self, msg, kwargs):
        """Merge the adapter's context into the call's ``extra``.

        On a key clash the adapter's value wins. A fresh dict is built so the
        dict supplied by the caller is never mutated, and an explicit
        ``extra=None`` is treated as empty.
        """
        call_extra = kwargs.get("extra") or {}
        kwargs["extra"] = {**call_extra, **(self.extra or {})}
        return msg, kwargs
58
+
59
+
60
+ class _EnsureContextFilter(logging.Filter):
61
+ """
62
+ Adds default values for context keys so format strings never KeyError.
63
+ """
64
+ def __init__(self, defaults: Mapping[str, Any] | None = None):
65
+ super().__init__()
66
+ self.defaults = dict(defaults or {})
67
+
68
+ def filter(self, record: logging.LogRecord) -> bool:
69
+ for k, v in self.defaults.items():
70
+ if not hasattr(record, k):
71
+ setattr(record, k, v)
72
+ return True
73
+
74
+
75
def setup_logger(
    name: str = "wisent",
    level: int = logging.INFO,
    *,
    json_logs: bool = False,
    stream = sys.stderr,
) -> logging.Logger:
    """
    Create or return a named logger with a single stream handler.
    Safe to call multiple times; won’t duplicate handlers.

    The level is applied on every call. The handler, formatter and
    ``propagate = False`` are only set up when the logger has no handlers yet,
    so ``json_logs`` and ``stream`` are ignored on later calls.

    Args:
        name: Logger name.
        level: Level set on the logger.
        json_logs: Use :class:`JsonFormatter` instead of the text format.
        stream: Stream for the handler (default bound to ``sys.stderr`` at
            import time).
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)
    if logger.handlers:
        return logger

    if json_logs:
        formatter: logging.Formatter = JsonFormatter()
    else:
        formatter = logging.Formatter(
            fmt="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s "
                "[file=%(filename)s func=%(funcName)s line=%(lineno)d] "
                "%(task_name)s%(subtask)s",
            datefmt="%Y-%m-%dT%H:%M:%S%z",
        )
    handler = logging.StreamHandler(stream)
    handler.setFormatter(formatter)
    # ensure context placeholders always exist
    handler.addFilter(_EnsureContextFilter({"task_name": "", "subtask": ""}))
    logger.addHandler(handler)
    logger.propagate = False
    return logger
104
+
105
+
106
def add_file_handler(
    logger: logging.Logger,
    filepath: str,
    *,
    level: int | None = None,
    json_logs: bool = False,
) -> None:
    """
    Optionally add a file handler (e.g., for long-running CLI jobs).

    Args:
        logger: Logger to attach the handler to.
        filepath: Log file path, opened with UTF-8 encoding.
        level: Handler level; ``None`` means "use the logger's current level".
        json_logs: Use :class:`JsonFormatter` instead of the text format.
    """
    fh = logging.FileHandler(filepath, encoding="utf-8")
    # Compare against None: logging.NOTSET is 0 and is a legitimate explicit
    # level, which a plain truthiness test would silently discard.
    fh.setLevel(logger.level if level is None else level)
    if json_logs:
        fh.setFormatter(JsonFormatter())
    else:
        fh.setFormatter(logging.Formatter(
            fmt="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s "
                "[file=%(filename)s func=%(funcName)s line=%(lineno)d] "
                "%(task_name)s%(subtask)s",
            datefmt="%Y-%m-%dT%H:%M:%S%z",
        ))
    # Supply empty defaults so the text format never hits a missing attribute.
    fh.addFilter(_EnsureContextFilter({"task_name": "", "subtask": ""}))
    logger.addHandler(fh)
129
+
130
+
131
def bind(
    logger: logging.Logger | ContextAdapter,
    **extra: Any
) -> ContextAdapter:
    """
    Return a ContextAdapter with merged extras.
    Works whether you pass a raw Logger or an existing ContextAdapter.

    When `logger` is already a ContextAdapter, a new adapter is created around
    the same underlying logger; its context is the old context updated with
    `extra` (new values win). The adapter passed in is not modified.
    """
    if isinstance(logger, ContextAdapter):
        # An adapter may have been built with extra=None; treat that as empty.
        merged = {**(logger.extra or {}), **extra}
        return ContextAdapter(logger.logger, merged)
    return ContextAdapter(logger, extra)
File without changes
@@ -0,0 +1,96 @@
1
+ from __future__ import annotations
2
+
3
+ import importlib
4
+ import importlib.util
5
+ import inspect
6
+ import pkgutil
7
+ from pathlib import Path
8
+ from typing import Any, Dict, List, Optional, Type, Union
9
+
10
+ from wisent.core.data_loaders.core.atoms import BaseDataLoader, DataLoaderError, LoadDataResult
11
+
12
class DataLoaderRotator:
    """Discover/select a data loader and use it to load data.

    Only loaders whose defining module is "in scope" are considered. For a
    dotted `loaders_location` that means the module name starts with that
    path. For a directory location the loaders are executed under synthetic
    module names, so those names are accepted as in scope too.
    """

    # Module-name prefix given to loaders imported from a directory.
    _DYN_PREFIX = "_dyn_dataloaders_"

    def __init__(
        self,
        loader: Union[str, BaseDataLoader, Type[BaseDataLoader], None] = None,
        loaders_location: Union[str, Path] = "wisent_guard.core.data_loaders.loaders",
        autoload: bool = True,
        **default_loader_kwargs: Any,
    ) -> None:
        """
        Args:
            loader: Registered name, instance, subclass, or ``None`` to pick
                the first in-scope loader by name.
            loaders_location: Dotted package path or directory to scan.
            autoload: Run :meth:`discover_loaders` before resolving.
            **default_loader_kwargs: Passed to the loader constructor, or
                merged into an existing instance's ``kwargs``.
        """
        # NOTE(review): the default location refers to a ``wisent_guard``
        # package while this module imports from ``wisent`` - confirm the
        # default is importable in the target environment.
        self._scope_prefix = (
            loaders_location if isinstance(loaders_location, str)
            else Path(loaders_location).as_posix().replace("/", ".")
        )
        # Directory-loaded modules are named "_dyn_dataloaders_<stem>", which
        # never starts with the directory-derived prefix above; without this
        # flag every loader found in a directory would be filtered out.
        self._scope_is_dir = Path(str(loaders_location)).is_dir()
        if autoload:
            self.discover_loaders(loaders_location)
        self._loader = self._resolve_loader(loader, **default_loader_kwargs)

    @staticmethod
    def discover_loaders(location: Union[str, Path]) -> None:
        """Import loader modules so BaseDataLoader subclasses can register.

        - Existing directory: every non-underscore ``*.py`` file is executed
          as module ``_dyn_dataloaders_<stem>`` (registered in ``sys.modules``
          first, per the importlib recipe), in sorted order.
        - Otherwise `location` must be a dotted module path; the package and
          each non-underscore module found on its path are imported.

        Raises:
            DataLoaderError: `location` is neither a directory nor a string.
        """
        import sys  # local import keeps this block self-contained

        loc_path = Path(str(location))
        if loc_path.exists() and loc_path.is_dir():
            for py in sorted(loc_path.glob("*.py")):
                if py.name.startswith("_"):
                    continue
                mod_name = f"{DataLoaderRotator._DYN_PREFIX}{py.stem}"
                spec = importlib.util.spec_from_file_location(mod_name, py)
                if not (spec and spec.loader):
                    continue
                module = importlib.util.module_from_spec(spec)
                sys.modules[mod_name] = module
                try:
                    spec.loader.exec_module(module)  # type: ignore[attr-defined]
                except BaseException:
                    # Do not leave a half-initialised module importable.
                    sys.modules.pop(mod_name, None)
                    raise
            return

        if not isinstance(location, str):
            raise DataLoaderError(f"Invalid loaders location: {location!r}. Provide dotted path or a directory.")

        pkg = importlib.import_module(location)
        search_paths = list(getattr(pkg, "__path__", []))
        if not search_paths:
            # Plain module: use its directory. __file__ can be None, so guard
            # before building a Path from it.
            pkg_file = getattr(pkg, "__file__", None)
            if pkg_file:
                search_paths = [Path(pkg_file).parent.as_posix()]
        for _, name, _ in pkgutil.iter_modules(search_paths):
            if name.startswith("_"):
                continue
            importlib.import_module(f"{location}.{name}")

    def _in_scope(self, cls: type) -> bool:
        """Tell whether `cls` was defined in a module this rotator may use."""
        module_name = cls.__module__
        if module_name.startswith(self._scope_prefix):
            return True
        return self._scope_is_dir and module_name.startswith(self._DYN_PREFIX)

    def _scoped_registry(self) -> dict[str, type[BaseDataLoader]]:
        """Return the registered loaders restricted to this rotator's scope."""
        reg = BaseDataLoader.list_registered()
        return {n: c for n, c in reg.items() if self._in_scope(c)}

    @staticmethod
    def list_loaders(scope_prefix: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return name/description/class-path records sorted by name.

        When `scope_prefix` is given, only loaders whose module name starts
        with it are listed.
        """
        reg = BaseDataLoader.list_registered()
        if scope_prefix:
            reg = {n: c for n, c in reg.items() if c.__module__.startswith(scope_prefix)}
        return [
            {"name": n, "description": getattr(c, "description", ""), "class": f"{c.__module__}.{c.__name__}"}
            for n, c in sorted(reg.items(), key=lambda kv: kv[0])
        ]

    def _resolve_loader(
        self,
        loader: Union[str, BaseDataLoader, Type[BaseDataLoader], None],
        **kwargs: Any,
    ) -> BaseDataLoader:
        """Turn `loader` into an instance.

        Raises:
            DataLoaderError: nothing is registered in scope, the name is
                unknown in scope, or the class lives outside the scope.
            TypeError: `loader` is of an unsupported type.
        """
        reg = self._scoped_registry()
        if loader is None:
            if not reg:
                raise DataLoaderError(f"No data loaders registered under {self._scope_prefix!r}.")
            # Deterministic pick: first by name.
            return reg[min(reg)](**kwargs)
        if isinstance(loader, BaseDataLoader):
            # The instance's own kwargs win over the rotator defaults. Tolerate
            # loaders that never set ``kwargs`` (load() does the same).
            own_kwargs = getattr(loader, "kwargs", None) or {}
            loader.kwargs = {**kwargs, **own_kwargs}
            return loader
        if inspect.isclass(loader) and issubclass(loader, BaseDataLoader):
            if not self._in_scope(loader):
                raise DataLoaderError(f"Loader class must live under {self._scope_prefix!r}.")
            return loader(**kwargs)
        if isinstance(loader, str):
            if loader not in reg:
                raise DataLoaderError(f"Unknown loader {loader!r} in scope {self._scope_prefix!r}.")
            return reg[loader](**kwargs)
        raise TypeError("loader must be None, a name (str), BaseDataLoader instance, or BaseDataLoader subclass.")

    def use(self, loader: Union[str, BaseDataLoader, Type[BaseDataLoader]], **kwargs: Any) -> None:
        """Replace the active loader."""
        self._loader = self._resolve_loader(loader, **kwargs)

    def load(self, **kwargs: Any) -> LoadDataResult:
        """Call the active loader's ``load``; call-time kwargs override stored ones."""
        merged = {**getattr(self._loader, "kwargs", {}), **kwargs}
        return self._loader.load(**merged)
File without changes
@@ -0,0 +1,148 @@
1
+ from __future__ import annotations
2
+
3
+ import importlib
4
+ import importlib.util
5
+ import pkgutil
6
+ from pathlib import Path
7
+ from typing import Any, Dict, List, Optional, Sequence, Union, Type
8
+ import inspect
9
+ import logging
10
+
11
+ from wisent.core.evaluators.core.atoms import BaseEvaluator, EvalResult, EvaluatorError
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
class EvaluatorRotator:
    """Orchestrates evaluator selection and execution with flexible discovery.

    The rotator keeps one active ``BaseEvaluator`` instance plus an optional
    default task name that is supplied to every evaluation call.
    """

    def __init__(
        self,
        evaluator: Union[str, BaseEvaluator, Type[BaseEvaluator], None] = None,
        task_name: Optional[str] = None,
        evaluators_location: Union[str, Path] = "wisent_guard.core.evaluators.oracles",
        autoload: bool = True,
    ) -> None:
        """
        Args:
            evaluator: Registered name, instance, subclass, or ``None`` for
                the default choice (see :meth:`_resolve_evaluator`).
            task_name: Default ``task_name`` keyword for evaluation calls.
            evaluators_location: Dotted package path or directory to scan.
            autoload: Run :meth:`discover_evaluators` before resolving.
        """
        # NOTE(review): the default location refers to a ``wisent_guard``
        # package while this module imports from ``wisent`` - confirm the
        # default is importable in the target environment.
        if autoload:
            self.discover_evaluators(evaluators_location)
        self._evaluator = self._resolve_evaluator(evaluator)
        self._task_name = task_name

    @staticmethod
    def discover_evaluators(location: Union[str, Path] = "wisent_guard.core.evaluators.oracles") -> None:
        """
        Import all evaluator modules so BaseEvaluator subclasses self-register.

        - If `location` is an existing directory (Path/str), import all .py files inside.
        - Otherwise `location` must be a dotted module path: that package is
          imported and every non-underscore module found on its ``__path__``
          (or next to its ``__file__``) is imported too.

        Raises:
            EvaluatorError: `location` is neither a directory nor a string,
                or the package cannot be imported.
        """
        loc_path = Path(str(location))
        if loc_path.exists() and loc_path.is_dir():
            EvaluatorRotator._import_all_py_in_dir(loc_path)
            return

        if not isinstance(location, str):
            raise EvaluatorError(
                f"Invalid evaluators location: {location!r}. Provide a dotted module path or a directory."
            )

        try:
            pkg = importlib.import_module(location)
        except ModuleNotFoundError as exc:
            raise EvaluatorError(
                f"Cannot import evaluator package {location!r}. "
                f"Use dotted path (no leading slash) and ensure your project root is on PYTHONPATH."
            ) from exc

        search_paths = list(getattr(pkg, "__path__", []))  # supports namespace pkgs
        if not search_paths:
            # Some packages may still have __file__ only
            pkg_file = getattr(pkg, "__file__", None)
            if pkg_file:
                search_paths = [str(Path(pkg_file).parent)]

        for _finder, name, _ispkg in pkgutil.iter_modules(search_paths):
            if name.startswith("_"):
                continue
            importlib.import_module(f"{location}.{name}")

    @staticmethod
    def _import_all_py_in_dir(directory: Path) -> None:
        """Execute every non-underscore ``*.py`` file in `directory` as a module.

        Modules are named ``_dyn_evaluators_<stem>`` and are registered in
        ``sys.modules`` before execution, as the importlib recipe for loading
        a source file requires. Files are processed in sorted order so the
        registration order is reproducible.
        """
        import sys  # local import keeps this block self-contained

        for py in sorted(directory.glob("*.py")):
            if py.name.startswith("_"):
                continue
            mod_name = f"_dyn_evaluators_{py.stem}"
            spec = importlib.util.spec_from_file_location(mod_name, py)
            if not (spec and spec.loader):
                continue
            module = importlib.util.module_from_spec(spec)
            sys.modules[mod_name] = module
            try:
                spec.loader.exec_module(module)  # type: ignore[attr-defined]
            except BaseException:
                # Do not leave a half-initialised module importable.
                sys.modules.pop(mod_name, None)
                raise

    @staticmethod
    def list_evaluators() -> List[Dict[str, Any]]:
        """Return name/description/task_names/class-path records, sorted by name."""
        out: List[Dict[str, Any]] = [
            {
                "name": name,
                "description": getattr(cls, "description", ""),
                "task_names": list(getattr(cls, "task_names", ())),
                "class": f"{cls.__module__}.{cls.__name__}",
            }
            for name, cls in BaseEvaluator.list_registered().items()
        ]
        return sorted(out, key=lambda x: x["name"])

    @staticmethod
    def _resolve_evaluator(
        evaluator: Union[str, BaseEvaluator, Type[BaseEvaluator], None]
    ) -> BaseEvaluator:
        """Turn `evaluator` into an instance.

        With ``None`` the evaluator registered as ``"lm_eval"`` is preferred;
        failing that, the first entry of the registry is used.

        Raises:
            EvaluatorError: `evaluator` is ``None`` and nothing is registered.
            TypeError: `evaluator` is of an unsupported type.
        """
        if evaluator is None:
            registry = BaseEvaluator.list_registered()
            if "lm_eval" in registry:
                return registry["lm_eval"]()
            if registry:
                return next(iter(registry.values()))()
            raise EvaluatorError("No evaluators registered.")
        if isinstance(evaluator, BaseEvaluator):
            return evaluator
        if inspect.isclass(evaluator) and issubclass(evaluator, BaseEvaluator):
            return evaluator()
        if isinstance(evaluator, str):
            cls = BaseEvaluator.get(evaluator)
            return cls()
        raise TypeError(
            "evaluator must be None, a name (str), BaseEvaluator instance, or BaseEvaluator subclass."
        )

    def use(self, evaluator: Union[str, BaseEvaluator, Type[BaseEvaluator]]) -> None:
        """Replace the active evaluator."""
        self._evaluator = self._resolve_evaluator(evaluator)

    def evaluate(self, response: str, expected: Any, **kwargs) -> EvalResult:
        """Evaluate one response; ``task_name`` defaults to the rotator's."""
        kwargs.setdefault("task_name", self._task_name)
        return self._evaluator.evaluate(response, expected, **kwargs)

    def evaluate_batch(
        self, responses: Sequence[str], expected_answers: Sequence[Any], **kwargs
    ) -> List[EvalResult]:
        """Evaluate many responses; ``task_name`` defaults to the rotator's."""
        kwargs.setdefault("task_name", self._task_name)
        return self._evaluator.evaluate_batch(responses, expected_answers, **kwargs)
130
+
131
+
132
if __name__ == "__main__":
    # Manual smoke check: discover the evaluators and print what registered.
    # EvaluatorRotator is already defined in this module, so it is used
    # directly instead of re-importing this file under a second module name.
    rot = EvaluatorRotator(
        evaluators_location="wisent_guard.core.evaluators.oracles",  # << no leading slash
        autoload=True,
    )

    print("Available evaluators:")
    for ev in rot.list_evaluators():
        print(f" - {ev['name']}: {ev['description']} (tasks: {', '.join(ev['task_names'])})")
File without changes