wisent 0.5.11__py3-none-any.whl → 0.5.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wisent might be problematic. Click here for more details.

Files changed (225) hide show
  1. wisent/__init__.py +1 -1
  2. wisent/core/activations/__init__.py +26 -0
  3. wisent/core/activations/activations.py +96 -0
  4. wisent/core/activations/activations_collector.py +71 -20
  5. wisent/core/activations/prompt_construction_strategy.py +47 -0
  6. wisent/core/agent/budget.py +2 -2
  7. wisent/core/agent/device_benchmarks.py +1 -1
  8. wisent/core/agent/diagnose/classifier_marketplace.py +8 -8
  9. wisent/core/agent/diagnose/response_diagnostics.py +4 -4
  10. wisent/core/agent/diagnose/synthetic_classifier_option.py +1 -1
  11. wisent/core/agent/diagnose/tasks/task_manager.py +3 -3
  12. wisent/core/agent/diagnose.py +2 -1
  13. wisent/core/autonomous_agent.py +10 -2
  14. wisent/core/benchmark_extractors.py +293 -0
  15. wisent/core/bigcode_integration.py +20 -7
  16. wisent/core/branding.py +108 -0
  17. wisent/core/cli/__init__.py +15 -0
  18. wisent/core/cli/create_steering_vector.py +138 -0
  19. wisent/core/cli/evaluate_responses.py +715 -0
  20. wisent/core/cli/generate_pairs.py +128 -0
  21. wisent/core/cli/generate_pairs_from_task.py +119 -0
  22. wisent/core/cli/generate_responses.py +129 -0
  23. wisent/core/cli/generate_vector_from_synthetic.py +149 -0
  24. wisent/core/cli/generate_vector_from_task.py +147 -0
  25. wisent/core/cli/get_activations.py +191 -0
  26. wisent/core/cli/optimize_classification.py +339 -0
  27. wisent/core/cli/optimize_steering.py +364 -0
  28. wisent/core/cli/tasks.py +182 -0
  29. wisent/core/cli_logger.py +22 -0
  30. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +27 -1
  31. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +49 -1
  32. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +115 -0
  33. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +115 -0
  34. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +111 -0
  35. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +119 -0
  36. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +115 -0
  37. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +114 -0
  38. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +118 -0
  39. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +146 -0
  40. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +129 -0
  41. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gsm8k.py +119 -0
  42. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +112 -0
  43. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +113 -0
  44. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/livecodebench.py +367 -0
  45. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +115 -0
  46. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +114 -0
  47. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +113 -0
  48. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +112 -0
  49. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +111 -0
  50. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +114 -0
  51. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +113 -0
  52. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +115 -0
  53. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +111 -0
  54. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +113 -0
  55. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +113 -0
  56. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +112 -0
  57. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +116 -0
  58. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +115 -0
  59. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +111 -0
  60. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +111 -0
  61. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +121 -0
  62. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +121 -0
  63. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +111 -0
  64. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +110 -0
  65. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +114 -0
  66. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/squad2.py +124 -0
  67. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sst2.py +111 -0
  68. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +112 -0
  69. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +127 -0
  70. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_gen.py +112 -0
  71. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +117 -0
  72. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +117 -0
  73. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +127 -0
  74. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +119 -0
  75. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +1 -1
  76. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +111 -0
  77. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +114 -0
  78. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +112 -0
  79. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +114 -0
  80. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +114 -0
  81. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +1 -1
  82. wisent/core/data_loaders/__init__.py +235 -0
  83. wisent/core/data_loaders/loaders/lm_loader.py +2 -2
  84. wisent/core/data_loaders/loaders/task_interface_loader.py +300 -0
  85. wisent/{cli/data_loaders/data_loader_rotator.py → core/data_loaders/rotator.py} +1 -1
  86. wisent/core/download_full_benchmarks.py +79 -2
  87. wisent/core/evaluators/benchmark_specific/__init__.py +26 -0
  88. wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/evaluator.py +17 -17
  89. wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/cpp_sanitizer.py +2 -2
  90. wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/java_sanitizer.py +2 -2
  91. wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/python_sanitizer.py +2 -2
  92. wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/__init__.py +3 -0
  93. wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/provider.py +305 -0
  94. wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/core/runtime.py +36 -4
  95. wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/entrypoint.py +2 -4
  96. wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/recipes.py +1 -1
  97. wisent/core/evaluators/benchmark_specific/coding/solution_generator.py +258 -0
  98. wisent/core/evaluators/benchmark_specific/exact_match_evaluator.py +79 -0
  99. wisent/core/evaluators/benchmark_specific/f1_evaluator.py +101 -0
  100. wisent/core/evaluators/benchmark_specific/generation_evaluator.py +197 -0
  101. wisent/core/{log_likelihoods_evaluator.py → evaluators/benchmark_specific/log_likelihoods_evaluator.py} +10 -2
  102. wisent/core/evaluators/benchmark_specific/perplexity_evaluator.py +140 -0
  103. wisent/core/evaluators/benchmark_specific/personalization_evaluator.py +250 -0
  104. wisent/{cli/evaluators/evaluator_rotator.py → core/evaluators/rotator.py} +4 -4
  105. wisent/core/lm_eval_harness_ground_truth.py +3 -2
  106. wisent/core/main.py +57 -0
  107. wisent/core/model_persistence.py +2 -2
  108. wisent/core/models/wisent_model.py +8 -6
  109. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
  110. wisent/core/optuna/steering/steering_optimization.py +1 -1
  111. wisent/core/parser_arguments/__init__.py +10 -0
  112. wisent/core/parser_arguments/agent_parser.py +110 -0
  113. wisent/core/parser_arguments/configure_model_parser.py +7 -0
  114. wisent/core/parser_arguments/create_steering_vector_parser.py +59 -0
  115. wisent/core/parser_arguments/evaluate_parser.py +40 -0
  116. wisent/core/parser_arguments/evaluate_responses_parser.py +10 -0
  117. wisent/core/parser_arguments/full_optimize_parser.py +115 -0
  118. wisent/core/parser_arguments/generate_pairs_from_task_parser.py +33 -0
  119. wisent/core/parser_arguments/generate_pairs_parser.py +29 -0
  120. wisent/core/parser_arguments/generate_responses_parser.py +15 -0
  121. wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +127 -0
  122. wisent/core/parser_arguments/generate_vector_from_task_parser.py +127 -0
  123. wisent/core/parser_arguments/generate_vector_parser.py +90 -0
  124. wisent/core/parser_arguments/get_activations_parser.py +90 -0
  125. wisent/core/parser_arguments/main_parser.py +152 -0
  126. wisent/core/parser_arguments/model_config_parser.py +59 -0
  127. wisent/core/parser_arguments/monitor_parser.py +17 -0
  128. wisent/core/parser_arguments/multi_steer_parser.py +47 -0
  129. wisent/core/parser_arguments/optimize_classification_parser.py +67 -0
  130. wisent/core/parser_arguments/optimize_sample_size_parser.py +58 -0
  131. wisent/core/parser_arguments/optimize_steering_parser.py +147 -0
  132. wisent/core/parser_arguments/synthetic_parser.py +93 -0
  133. wisent/core/parser_arguments/tasks_parser.py +584 -0
  134. wisent/core/parser_arguments/test_nonsense_parser.py +26 -0
  135. wisent/core/parser_arguments/utils.py +111 -0
  136. wisent/core/prompts/core/prompt_formater.py +3 -3
  137. wisent/core/prompts/prompt_stratiegies/direct_completion.py +2 -0
  138. wisent/core/prompts/prompt_stratiegies/instruction_following.py +2 -0
  139. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +2 -0
  140. wisent/core/prompts/prompt_stratiegies/role_playing.py +2 -0
  141. wisent/{cli/steering_methods/steering_rotator.py → core/steering_methods/rotator.py} +4 -4
  142. wisent/core/steering_optimizer.py +45 -21
  143. wisent/{synthetic → core/synthetic}/cleaners/deduper_cleaner.py +3 -3
  144. wisent/{synthetic → core/synthetic}/cleaners/methods/base_dedupers.py +2 -2
  145. wisent/{synthetic → core/synthetic}/cleaners/methods/base_refusalers.py +1 -1
  146. wisent/{synthetic → core/synthetic}/cleaners/pairs_cleaner.py +5 -5
  147. wisent/{synthetic → core/synthetic}/cleaners/refusaler_cleaner.py +4 -4
  148. wisent/{synthetic → core/synthetic}/db_instructions/mini_dp.py +1 -1
  149. wisent/{synthetic → core/synthetic}/generators/diversities/methods/fast_diversity.py +1 -1
  150. wisent/{synthetic → core/synthetic}/generators/pairs_generator.py +38 -12
  151. wisent/core/tasks/livecodebench_task.py +4 -103
  152. wisent/core/timing_calibration.py +1 -1
  153. {wisent-0.5.11.dist-info → wisent-0.5.13.dist-info}/METADATA +3 -3
  154. wisent-0.5.13.dist-info/RECORD +294 -0
  155. wisent-0.5.13.dist-info/entry_points.txt +2 -0
  156. wisent/benchmarks/coding/providers/livecodebench/provider.py +0 -53
  157. wisent/classifiers/core/atoms.py +0 -747
  158. wisent/classifiers/models/logistic.py +0 -29
  159. wisent/classifiers/models/mlp.py +0 -47
  160. wisent/cli/classifiers/classifier_rotator.py +0 -137
  161. wisent/cli/cli_logger.py +0 -142
  162. wisent/cli/wisent_cli/commands/help_cmd.py +0 -52
  163. wisent/cli/wisent_cli/commands/listing.py +0 -154
  164. wisent/cli/wisent_cli/commands/train_cmd.py +0 -322
  165. wisent/cli/wisent_cli/main.py +0 -93
  166. wisent/cli/wisent_cli/shell.py +0 -80
  167. wisent/cli/wisent_cli/ui.py +0 -69
  168. wisent/cli/wisent_cli/util/aggregations.py +0 -43
  169. wisent/cli/wisent_cli/util/parsing.py +0 -126
  170. wisent/cli/wisent_cli/version.py +0 -4
  171. wisent/opti/methods/__init__.py +0 -0
  172. wisent/synthetic/__init__.py +0 -0
  173. wisent/synthetic/cleaners/__init__.py +0 -0
  174. wisent/synthetic/cleaners/core/__init__.py +0 -0
  175. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  176. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  177. wisent/synthetic/db_instructions/__init__.py +0 -0
  178. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  179. wisent/synthetic/generators/__init__.py +0 -0
  180. wisent/synthetic/generators/core/__init__.py +0 -0
  181. wisent/synthetic/generators/diversities/__init__.py +0 -0
  182. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  183. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  184. wisent-0.5.11.dist-info/RECORD +0 -220
  185. /wisent/{benchmarks → core/evaluators/benchmark_specific/coding}/__init__.py +0 -0
  186. /wisent/{benchmarks/coding → core/evaluators/benchmark_specific/coding/metrics}/__init__.py +0 -0
  187. /wisent/{benchmarks/coding/metrics → core/evaluators/benchmark_specific/coding/metrics/core}/__init__.py +0 -0
  188. /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/core/atoms.py +0 -0
  189. /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/passk.py +0 -0
  190. /wisent/{benchmarks/coding/metrics/core → core/evaluators/benchmark_specific/coding/output_sanitizer}/__init__.py +0 -0
  191. /wisent/{benchmarks/coding/output_sanitizer → core/evaluators/benchmark_specific/coding/output_sanitizer/core}/__init__.py +0 -0
  192. /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/core/atoms.py +0 -0
  193. /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/utils.py +0 -0
  194. /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/providers/__init__.py +0 -0
  195. /wisent/{benchmarks/coding/output_sanitizer → core/evaluators/benchmark_specific/coding/providers}/core/__init__.py +0 -0
  196. /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/providers/core/atoms.py +0 -0
  197. /wisent/{benchmarks/coding/providers/core → core/evaluators/benchmark_specific/coding/safe_docker}/__init__.py +0 -0
  198. /wisent/{benchmarks/coding/providers/livecodebench → core/evaluators/benchmark_specific/coding/safe_docker/core}/__init__.py +0 -0
  199. /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/core/atoms.py +0 -0
  200. /wisent/{benchmarks/coding/safe_docker → core/opti}/__init__.py +0 -0
  201. /wisent/{benchmarks/coding/safe_docker → core/opti}/core/__init__.py +0 -0
  202. /wisent/{opti → core/opti}/core/atoms.py +0 -0
  203. /wisent/{classifiers → core/opti/methods}/__init__.py +0 -0
  204. /wisent/{opti → core/opti}/methods/opti_classificator.py +0 -0
  205. /wisent/{opti → core/opti}/methods/opti_steering.py +0 -0
  206. /wisent/{classifiers/core → core/synthetic}/__init__.py +0 -0
  207. /wisent/{classifiers/models → core/synthetic/cleaners}/__init__.py +0 -0
  208. /wisent/{cli → core/synthetic/cleaners/core}/__init__.py +0 -0
  209. /wisent/{synthetic → core/synthetic}/cleaners/core/atoms.py +0 -0
  210. /wisent/{cli/classifiers → core/synthetic/cleaners/methods}/__init__.py +0 -0
  211. /wisent/{cli/data_loaders → core/synthetic/cleaners/methods/core}/__init__.py +0 -0
  212. /wisent/{synthetic → core/synthetic}/cleaners/methods/core/atoms.py +0 -0
  213. /wisent/{cli/evaluators → core/synthetic/db_instructions}/__init__.py +0 -0
  214. /wisent/{cli/steering_methods → core/synthetic/db_instructions/core}/__init__.py +0 -0
  215. /wisent/{synthetic → core/synthetic}/db_instructions/core/atoms.py +0 -0
  216. /wisent/{cli/wisent_cli → core/synthetic/generators}/__init__.py +0 -0
  217. /wisent/{cli/wisent_cli/commands → core/synthetic/generators/core}/__init__.py +0 -0
  218. /wisent/{synthetic → core/synthetic}/generators/core/atoms.py +0 -0
  219. /wisent/{cli/wisent_cli/util → core/synthetic/generators/diversities}/__init__.py +0 -0
  220. /wisent/{opti → core/synthetic/generators/diversities/core}/__init__.py +0 -0
  221. /wisent/{synthetic → core/synthetic}/generators/diversities/core/core.py +0 -0
  222. /wisent/{opti/core → core/synthetic/generators/diversities/methods}/__init__.py +0 -0
  223. {wisent-0.5.11.dist-info → wisent-0.5.13.dist-info}/WHEEL +0 -0
  224. {wisent-0.5.11.dist-info → wisent-0.5.13.dist-info}/licenses/LICENSE +0 -0
  225. {wisent-0.5.11.dist-info → wisent-0.5.13.dist-info}/top_level.txt +0 -0
@@ -1,747 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import logging
4
- import os
5
- from abc import ABC, abstractmethod
6
- from dataclasses import dataclass, asdict
7
- from typing import Any, Callable
8
-
9
- import torch
10
- import torch.nn as nn
11
- import torch.optim as optim
12
- from torch.utils.data import DataLoader, TensorDataset, random_split
13
- import numpy as np
14
-
15
- from torch.nn.modules.loss import _Loss
16
-
17
- __all__ = [
18
- "ClassifierTrainConfig",
19
- "ClassifierMetrics",
20
- "ClassifierTrainReport",
21
- "ClassifierError",
22
- "BaseClassifier",
23
- ]
24
-
25
- logger = logging.getLogger(__name__)
26
-
27
-
28
- @dataclass(slots=True, frozen=True)
29
- class ClassifierTrainConfig:
30
- """
31
- Training configuration for classifiers.
32
-
33
- attributes:
34
- test_size:
35
- fraction of data to hold out for testing
36
- num_epochs:
37
- maximum number of training epochs
38
- batch_size:
39
- training batch size
40
- learning_rate:
41
- optimizer learning rate
42
- monitor:
43
- which metric to monitor for best epoch selection
44
- random_state:
45
- random seed for data shuffling and initialization
46
- """
47
- test_size: float = 0.2
48
- num_epochs: int = 50
49
- batch_size: int = 32
50
- learning_rate: float = 1e-3
51
- monitor: str = "accuracy"
52
- random_state: int = 42
53
-
54
- @dataclass(slots=True, frozen=True)
55
- class ClassifierMetrics:
56
- """
57
- Evaluation metrics for classifiers.
58
-
59
- attributes:
60
- accuracy: float
61
- Overall accuracy of predictions.
62
- precision: float
63
- Precision (positive predictive value).
64
- recall: float
65
- Recall (sensitivity).
66
- f1: float
67
- F1 score (harmonic mean of precision and recall).
68
- auc: float
69
- Area under the ROC curve.
70
- """
71
- accuracy: float
72
- precision: float
73
- recall: float
74
- f1: float
75
- auc: float
76
-
77
- @dataclass(slots=True, frozen=True)
78
- class ClassifierTrainReport:
79
- """
80
- Training report for classifiers.
81
-
82
- attributes:
83
- classifier_name: str
84
- Name of the classifier.
85
- input_dim: int
86
- Dimensionality of the input features.
87
- best_epoch: int
88
- Epoch number of the best model.
89
- epochs_ran: int
90
- Total number of epochs run.
91
- final: ClassifierMetrics
92
- Final evaluation metrics on the test set. It contains accuracy, precision, recall, f1, and auc.
93
- history: dict[str, list[float]]
94
-
95
-
96
- """
97
- classifier_name: str
98
- input_dim: int
99
- best_epoch: int
100
- epochs_ran: int
101
- final: ClassifierMetrics
102
- history: dict[str, list[float]]
103
-
104
- def asdict(self) -> dict[str, str | int | float | dict]:
105
- """
106
- Return a dictionary representation of the report.
107
-
108
- returns:
109
- A dictionary with all report fields, including nested metrics.
110
-
111
- example:
112
- >>> report.asdict()
113
- {
114
- "classifier_name": "mlp",
115
- "input_dim": 4,
116
- "best_epoch": 23,
117
- "epochs_ran": 30,
118
- "final": {
119
- "accuracy": 0.95,
120
- "precision": 0.96,
121
- "recall": 0.94,
122
- "f1": 0.95,
123
- "auc": 0.98
124
- },
125
- "history": {
126
- "train_loss": [...],
127
- "test_loss": [...],
128
- "accuracy": [...],
129
- "precision": [...],
130
- "recall": [...],
131
- "f1": [...],
132
- "auc": [...]
133
- }
134
- }
135
- """
136
- d = asdict(self); d["final"] = asdict(self.final); return d
137
-
138
- class ClassifierError(RuntimeError):
139
- pass
140
-
141
- class BaseClassifier(ABC):
142
- name: str = "base"
143
- description: str = "Abstract classifier"
144
-
145
- _REGISTRY: dict[str, type[BaseClassifier]] = {}
146
-
147
- model: nn.Module | None
148
- device: str
149
- dtype: torch.dtype
150
- threshold: float
151
-
152
- def __init_subclass__(cls, **kwargs) -> None:
153
- super().__init_subclass__(**kwargs)
154
- if cls is BaseClassifier:
155
- return
156
- if not getattr(cls, "name", None):
157
- raise TypeError("Classifier subclasses must define class attribute `name`.")
158
- if cls.name in BaseClassifier._REGISTRY:
159
- raise ValueError(f"Duplicate classifier name: {cls.name!r}")
160
- BaseClassifier._REGISTRY[cls.name] = cls
161
-
162
- def __init__(
163
- self,
164
- threshold: float = 0.5,
165
- device: str | None = None,
166
- dtype: torch.dtype = torch.float32,
167
- ) -> None:
168
- if not 0.0 <= threshold <= 1.0:
169
- raise ValueError("threshold must be in [0.0, 1.0]")
170
- self.threshold = threshold
171
- self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
172
- self.dtype = torch.float32 if self.device == "mps" else dtype
173
- self.model = None
174
-
175
- @abstractmethod
176
- def build_model(self, input_dim: int, **model_params: Any) -> nn.Module:
177
- """Return a torch.nn.Module that outputs P(y=1) ∈ [0,1]."""
178
- raise NotImplementedError
179
-
180
- def model_hyperparams(self) -> dict[str, Any]:
181
- return {}
182
-
183
- def fit(
184
- self,
185
- X,
186
- y,
187
- config: ClassifierTrainConfig | None = None,
188
- optimizer: str | optim.Optimizer | callable | None = None,
189
- lr: float | None = None,
190
- optimizer_kwargs: dict | None = None,
191
- criterion: nn.Module | str | None = None,
192
- on_epoch_end: Callable[[int, dict[str, float]], bool | None] | None = None,
193
- **model_params: Any,
194
- ) -> ClassifierTrainReport:
195
-
196
- #1 creating
197
- cfg = config or ClassifierTrainConfig()
198
- torch.manual_seed(cfg.random_state)
199
-
200
- #2 creating tensors
201
- X_tensor = self.to_2d_tensor(X, device=self.device, dtype=self.dtype)
202
- y_tensor = self.to_1d_tensor(y, device=self.device, dtype=self.dtype)
203
-
204
- #3 checking dimensions
205
- if X_tensor.shape[0] != y_tensor.shape[0]:
206
- raise ClassifierError(f"X and y length mismatch: {X_tensor.shape[0]} vs {y_tensor.shape[0]}")
207
-
208
- if self.model is None:
209
- input_dim = int(X_tensor.shape[1])
210
- self.model = self.build_model(input_dim, **model_params).to(self.device)
211
-
212
-
213
- # 4 creating dataloaders
214
- train_loader, test_loader = self._make_dataloaders(X_tensor, y_tensor, cfg)
215
-
216
- # 5 creating criterion and optimizer
217
- crit = self._make_criterion(criterion) if criterion is not None else self.configure_criterion()
218
- learn_rate = lr if lr is not None else cfg.learning_rate
219
- opt = self._make_optimizer(self.model, optimizer, learn_rate, optimizer_kwargs or {})
220
-
221
- # 6 training loop
222
- best_metric = float("-inf")
223
- best_state: dict[str, torch.Tensor] | None = None
224
-
225
- # 7 history
226
- history: dict[str, list[float]] = {
227
- "train_loss": [], "test_loss": [],
228
- "accuracy": [], "precision": [], "recall": [], "f1": [], "auc": [],
229
- }
230
-
231
- # 8 main loop
232
- for epoch in range(cfg.num_epochs):
233
- # one epoch
234
- train_loss = self._train_one_epoch(self.model, train_loader, opt, crit)
235
- test_loss, probs, labels = self._eval_one_epoch(self.model, test_loader, crit)
236
-
237
- preds = [1.0 if p >= self.threshold else 0.0 for p in probs]
238
- acc, prec, rec, f1 = self._basic_prf(preds, labels)
239
- auc = self._roc_auc(labels, probs)
240
-
241
- history["train_loss"].append(train_loss)
242
- history["test_loss"].append(test_loss)
243
- history["accuracy"].append(acc)
244
- history["precision"].append(prec)
245
- history["recall"].append(rec)
246
- history["f1"].append(f1)
247
- history["auc"].append(auc)
248
-
249
- # keep best checkpoint by cfg.monitor
250
- monitored = history[cfg.monitor][-1]
251
- if monitored > best_metric:
252
- best_metric = monitored
253
- best_state = {k: v.detach().clone() for k, v in self.model.state_dict().items()}
254
-
255
- # optional external observer/pruner
256
- if on_epoch_end is not None:
257
- stop = on_epoch_end(epoch, {k: history[k][-1] for k in history})
258
- if stop:
259
- break
260
-
261
- if (epoch == 0) or ((epoch + 1) % 10 == 0) or (epoch == cfg.num_epochs - 1):
262
- logger.info("[%s] epoch %d/%d train=%.4f test=%.4f acc=%.4f f1=%.4f",
263
- self.name, epoch + 1, cfg.num_epochs, train_loss, test_loss, acc, f1)
264
-
265
- if best_state is not None:
266
- self.model.load_state_dict(best_state)
267
-
268
- # final pass
269
- test_loss, probs, labels = self._eval_one_epoch(self.model, test_loader, crit)
270
- preds = [1.0 if p >= self.threshold else 0.0 for p in probs]
271
- acc, prec, rec, f1 = self._basic_prf(preds, labels)
272
- auc = self._roc_auc(labels, probs)
273
- final = ClassifierMetrics(acc, prec, rec, f1, auc)
274
-
275
- best_epoch = int(max(range(len(history[cfg.monitor])), key=history[cfg.monitor].__getitem__) + 1)
276
- return ClassifierTrainReport(
277
- classifier_name=self.name,
278
- input_dim=input_dim,
279
- best_epoch=best_epoch,
280
- epochs_ran=len(history["accuracy"]),
281
- final=final,
282
- history={k: [float(v) for v in vs] for k, vs in history.items()},
283
- )
284
-
285
- def _make_dataloaders(
286
- self,
287
- X: torch.Tensor | np.ndarray,
288
- y: torch.Tensor | np.ndarray,
289
- cfg: ClassifierTrainConfig,
290
- ) -> tuple[DataLoader, DataLoader]:
291
- """
292
- Split (X, y) into train/test using a seeded random split and wrap each in DataLoaders.
293
-
294
- arguments:
295
- X:
296
- 2D feature array or tensor.
297
- y:
298
- 1D label array or tensor.
299
- cfg:
300
- training configuration with test_size, batch_size, and random_state.
301
-
302
- returns:
303
- tuple of (train_dataloader, test_dataloader)
304
-
305
- example:
306
- >>> X = np.random.randn(100, 2).astype(np.float32)
307
- >>> print(X.shape)
308
- (100, 2)
309
- >>> print(X[0])
310
- [ 0.123 -1.456]
311
- >>> y = np.random.randint(0, 2, size=(100,)).astype(np.int64)
312
- >>> print(y.shape)
313
- (100,)
314
- >>> print(y[0])
315
- 1
316
- >>> cfg = ClassifierTrainConfig(test_size=0.2, batch_size=16, random_state=42)
317
- >>> train_loader, test_loader = self._make_dataloaders(X, y, cfg)
318
- >>> print(len(train_loader.dataset), len(test_loader.dataset))
319
- (80, 20)
320
- >>> xb, yb = next(iter(train_loader))
321
- >>> print(xb.shape, yb.shape)
322
- (16, 2) (16,)
323
- """
324
-
325
- if isinstance(X, np.ndarray): X = torch.from_numpy(X)
326
- if isinstance(y, np.ndarray): y = torch.from_numpy(y)
327
-
328
- ds = TensorDataset(X, y)
329
-
330
- if len(ds) < 2:
331
- return (
332
- DataLoader(ds, batch_size=cfg.batch_size, shuffle=True),
333
- DataLoader(ds, batch_size=cfg.batch_size, shuffle=False),
334
- )
335
-
336
- test_count = max(1, int(round(cfg.test_size * len(ds))))
337
- test_count = min(test_count, len(ds) - 1)
338
- train_count = len(ds) - test_count
339
-
340
- gen = torch.Generator().manual_seed(cfg.random_state)
341
- train_ds, test_ds = random_split(ds, [train_count, test_count], generator=gen)
342
-
343
- return (
344
- DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True),
345
- DataLoader(test_ds, batch_size=cfg.batch_size, shuffle=False),
346
- )
347
-
348
-
349
- def predict(self, X: torch.Tensor | np.ndarray) -> int | list[int]:
350
- """
351
- Predict class labels for the given input.
352
-
353
- arguments:
354
- X:
355
- 2D feature array or tensor.
356
-
357
- returns:
358
- predicted class label(s) as int or list of int.
359
-
360
- example:
361
- >>> X = np.random.randn(5, 2).astype(np.float32)
362
- >>> print(X)
363
- [[ 0.123 -1.456]
364
- [ 0.789 0.012]
365
- [-0.345 0.678]
366
- [ 1.234 -0.567]
367
- [-0.890 -1.234]]
368
- >>> preds = self.predict(X)
369
- >>> print(preds)
370
- [0, 1, 1, 0, 0]
371
- """
372
- self._require_model()
373
-
374
- X_tensor = self.to_2d_tensor(X, device=self.device, dtype=self.dtype)
375
-
376
- with torch.no_grad():
377
- probs = self._forward_probs(self.model, X_tensor).view(-1).cpu().tolist()
378
- preds = [1 if p >= self.threshold else 0 for p in probs]
379
- return preds[0] if len(preds) == 1 else preds
380
-
381
- def predict_proba(self, X: torch.Tensor | np.ndarray) -> float | list[float]:
382
- """
383
- Predict class probabilities for the given input.
384
-
385
- arguments:
386
- X: 2D feature array or tensor.
387
-
388
- returns:
389
- predicted class probability
390
-
391
- example:
392
- >>> X = np.random.randn(5, 2).astype(np.float32)
393
- >>> print(X)
394
- [[ 0.123 -1.456]
395
- [ 0.789 0.012]
396
- [-0.345 0.678]
397
- [ 1.234 -0.567]
398
- [-0.890 -1.234]]
399
- >>> probs = self.predict_proba(X)
400
- >>> print(probs)
401
- [0.23, 0.76, 0.54, 0.12, 0.34]
402
- """
403
- self._require_model()
404
-
405
- X_tensor = self.to_2d_tensor(X, device=self.device, dtype=self.dtype)
406
-
407
- with torch.no_grad():
408
- probs = self._forward_probs(self.model, X_tensor).view(-1).cpu().tolist()
409
- return probs[0] if len(probs) == 1 else probs
410
-
411
- def evaluate(self, X: torch.Tensor | np.ndarray, y: torch.Tensor | np.ndarray) -> dict[str, float]:
412
- """
413
- Evaluate the model on the given dataset and return metrics.
414
-
415
- arguments:
416
- X:
417
- 2D feature array or tensor.
418
- y:
419
- 1D label array or tensor.
420
-
421
- returns:
422
- dictionary of evaluation metrics.
423
-
424
- flow:
425
- >>> X = np.random.randn(2, 2).astype(np.float32)
426
- >>> y = np.random.randint(0, 2, size=(2,)).astype(np.int64)
427
- >>> print(X)
428
- [[ 0.123 -1.456]
429
- [ 0.789 0.012]]
430
- >>> print(y)
431
- [1, 0]
432
- >>> y_pred = self.predict(X)
433
- >>> print(y_pred)
434
- [0, 0]
435
- >>> y_prob = self.predict_proba(X)
436
- >>> print(y_prob)
437
- [0.34, 0.12]
438
- >>> metrics = self.evaluate(X, y)
439
- >>> print(metrics)
440
- {'accuracy': 0.5, ...}
441
- """
442
- y_pred = self.predict(X)
443
- y_prob = self.predict_proba(X)
444
- preds = [float(y_pred)] if isinstance(y_pred, int) else [float(v) for v in y_pred]
445
- probs = [float(y_prob)] if isinstance(y_prob, float) else [float(v) for v in y_prob]
446
- labels = y.detach().cpu().view(-1).tolist() if isinstance(y, torch.Tensor) else list(y)
447
- acc, prec, rec, f1 = self._basic_prf(preds, labels)
448
- auc = self._roc_auc(labels, probs)
449
- return {"accuracy": acc, "precision": prec, "recall": rec, "f1": f1, "auc": auc}
450
-
451
- def configure_criterion(self) -> nn.Module: return nn.BCELoss()
452
-
453
- def _make_criterion(self, spec: nn.Module | str) -> nn.Module:
454
- """
455
- Create a loss criterion from a string or module.
456
-
457
- arguments:
458
- spec:
459
- loss specification, either a string or a torch.nn.Module instance.
460
-
461
- returns:
462
- a torch.nn.Module loss function.
463
-
464
- raises:
465
- ValueError:
466
- if the string specification is unknown.
467
- """
468
- if isinstance(spec, nn.Module): return spec
469
- key = str(spec).strip().lower()
470
- if key in {"bce", "bceloss"}: return nn.BCELoss()
471
- if key in {"bcewithlogits", "bcewithlogitsloss"}: return nn.BCEWithLogitsLoss()
472
- raise ValueError(f"Unknown criterion: {spec!r}")
473
-
474
def configure_optimizer(self, model: nn.Module, lr: float) -> optim.Optimizer:
    """
    Build the default optimizer: Adam over all of the model's parameters.

    arguments:
        model:
            the model whose parameters are optimized.
        lr:
            the learning rate passed to Adam.

    returns:
        an Adam optimizer instance.
    """
    trainable = model.parameters()
    adam = optim.Adam(trainable, lr=lr)
    return adam
487
-
488
- def _make_optimizer(self, model: nn.Module, spec: str | optim.Optimizer | None, lr: float, extra: dict) -> optim.Optimizer:
489
- """
490
- Create an optimizer from a specification.
491
-
492
- arguments:
493
- model:
494
- the model to optimize.
495
- spec:
496
- optimizer specification: string, instance, callable, or None for default.
497
- lr:
498
- learning rate.
499
- extra:
500
- extra keyword arguments for optimizer constructor.
501
-
502
- returns:
503
- an optimizer instance.
504
-
505
- raises:
506
- ValueError:
507
- if the string specification is unknown.
508
- TypeError:
509
- if the specification type is unsupported.
510
- """
511
- if isinstance(spec, optim.Optimizer): return spec
512
- if spec is None: return self.configure_optimizer(model, lr)
513
- if isinstance(spec, str):
514
- try: cls = getattr(optim, spec)
515
- except AttributeError as exc: raise ValueError(f"Unknown optimizer: {spec!r}") from exc
516
- return cls(model.parameters(), lr=lr, **extra)
517
- if callable(spec): return spec(model.parameters(), lr=lr, **extra)
518
- raise TypeError(f"Unsupported optimizer spec: {type(spec)}")
519
-
520
- def _train_one_epoch(self, model: nn.Module, loader: DataLoader, optimizer: optim.Optimizer, criterion: _Loss) -> float:
521
- """
522
- Train the model for one epoch over the given DataLoader.
523
-
524
- arguments:
525
- model:
526
- the model to train.
527
- loader:
528
- DataLoader for training data.
529
- optimizer:
530
- optimizer instance.
531
- criterion:
532
- loss function.
533
-
534
- returns:
535
- average training loss over the epoch.
536
- """
537
- model.train(); total = 0.0; steps = 0
538
- xb: torch.Tensor; yb: torch.Tensor
539
-
540
- for xb, yb in loader:
541
- optimizer.zero_grad(set_to_none=True)
542
- out = self._forward_probs(model, xb)
543
- loss = criterion(out.view(-1), yb.view(-1))
544
- loss.backward(); optimizer.step()
545
- total += float(loss.item()); steps += 1
546
- return total / max(steps, 1)
547
-
548
- def _eval_one_epoch(self, model: nn.Module, loader: DataLoader, criterion: _Loss) -> float:
549
- """
550
- Evaluate the model for one epoch over the given DataLoader.
551
-
552
- arguments:
553
- model:
554
- the model to evaluate.
555
- loader:
556
- DataLoader for evaluation data.
557
- criterion:
558
- loss function.
559
-
560
- returns:
561
- average evaluation loss over the epoch.
562
- """
563
- model.eval(); total = 0.0; steps = 0; probs_all=[]; labels_all=[]
564
- with torch.no_grad():
565
- xb: torch.Tensor; yb: torch.Tensor
566
- for xb, yb in loader:
567
- out = self._forward_probs(model, xb)
568
- loss = criterion(out.view(-1), yb.view(-1))
569
- total += float(loss.item()); steps += 1
570
- probs_all.extend(out.detach().cpu().view(-1).tolist())
571
- labels_all.extend(yb.detach().cpu().view(-1).tolist())
572
- return (total / max(steps, 1), probs_all, labels_all)
573
-
574
- def _forward_probs(self, model: nn.Module, xb: torch.Tensor) -> torch.Tensor:
575
- """
576
- Forward pass to get predicted probabilities.
577
-
578
- arguments:
579
- model:
580
- the model to use.
581
- xb:
582
- input feature tensor.
583
-
584
- returns:
585
- tensor of predicted probabilities.
586
- """
587
- if xb.device.type != self.device: xb = xb.to(self.device)
588
- if xb.dtype != self.dtype: xb = xb.to(self.dtype)
589
- out = model(xb)
590
- return out.view(-1, 1) if out.ndim == 1 else out
591
-
592
def save_model(self, path: str) -> None:
    """
    Write the model weights plus the metadata needed to rebuild it.

    arguments:
        path:
            destination file; missing parent directories are created.

    raises:
        ClassifierError:
            if the model is not initialized.
    """
    self._require_model()

    target_dir = os.path.dirname(os.path.abspath(path)) or "."
    os.makedirs(target_dir, exist_ok=True)

    # Input width is taken from dimension 1 of the model's first parameter
    # (presumably the first layer's weight matrix -- confirm for new models).
    first_param = next(self.model.parameters())
    checkpoint = {
        "classifier_name": self.name,
        "state_dict": self.model.state_dict(),
        "input_dim": int(first_param.shape[1]),
        "threshold": self.threshold,
        "model_hyperparams": self.model_hyperparams(),
    }
    torch.save(checkpoint, path)
    logger.info("Saved %s to %s", self.name, path)
614
-
615
def load_model(self, path: str) -> None:
    """
    Restore threshold, architecture and weights from a checkpoint file.

    arguments:
        path:
            the file path to load the model from.

    raises:
        FileNotFoundError:
            if the model file does not exist.
        ClassifierError:
            if the checkpoint is not a dict holding "state_dict" and
            "input_dim".
    """
    if not os.path.exists(path):
        raise FileNotFoundError(path)

    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects -- only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=self.device, weights_only=False)

    well_formed = (
        isinstance(checkpoint, dict)
        and "state_dict" in checkpoint
        and "input_dim" in checkpoint
    )
    if not well_formed:
        raise ClassifierError("Unsupported checkpoint format.")

    # A checkpoint without a stored threshold keeps the current one.
    self.threshold = float(checkpoint.get("threshold", self.threshold))
    input_dim = int(checkpoint["input_dim"])
    hyperparams = dict(checkpoint.get("model_hyperparams", {}))

    self.model = self.build_model(input_dim, **hyperparams).to(self.device)
    self.model.load_state_dict(checkpoint["state_dict"])
    self.model.eval()
638
-
639
- def _require_model(self) -> None:
640
- if self.model is None:
641
- raise ClassifierError("Model not initialized. Call fit() or load_model() first.")
642
-
643
@classmethod
def to_2d_tensor(cls, X, device: str, dtype: torch.dtype) -> torch.Tensor:
    """
    Coerce features into a 2D tensor on the requested device and dtype.

    arguments:
        X:
            input data as array-like or tensor; a 1D input is treated as
            a single row.
        device:
            target device string.
        dtype:
            target torch dtype.

    returns:
        2D torch tensor.

    raises:
        ClassifierError:
            if the result is not 2D after promoting 1D input to one row.
    """
    if isinstance(X, torch.Tensor):
        tensor = X.to(device=device, dtype=dtype)
    else:
        tensor = torch.tensor(X, device=device, dtype=dtype)

    if tensor.ndim == 1:
        tensor = tensor.view(1, -1)
    if tensor.ndim != 2:
        raise ClassifierError(f"Expected 2D features, got {tuple(tensor.shape)}")
    return tensor
672
-
673
@staticmethod
def to_1d_tensor(y, *, device: str, dtype: torch.dtype) -> torch.Tensor:
    """
    Convert input to a 1D tensor on the specified device and dtype.

    arguments:
        y:
            input data as array-like or tensor; any shape is flattened.
        device:
            target device string.
        dtype:
            target torch dtype.

    returns:
        1D torch tensor.
    """
    if isinstance(y, torch.Tensor):
        # reshape, not view: view(-1) raises on non-contiguous input such as
        # a transposed tensor, while reshape copies only when it has to.
        return y.to(device=device, dtype=dtype).reshape(-1)
    return torch.tensor(list(y), device=device, dtype=dtype).reshape(-1)
696
-
697
- @staticmethod
698
- def _basic_prf(preds: list[float], labels: list[float]) -> tuple[float, float, float, float]:
699
- """
700
- Compute basic precision, recall, and F1 score.
701
-
702
- arguments:
703
- preds:
704
- list of predicted labels (0.0 or 1.0).
705
- labels:
706
- list of true labels (0.0 or 1.0).
707
-
708
- returns:
709
- tuple of (accuracy, precision, recall, f1).
710
- """
711
- tp = sum(1 for p, l in zip(preds, labels) if p == 1 and l == 1)
712
- fp = sum(1 for p, l in zip(preds, labels) if p == 1 and l == 0)
713
- fn = sum(1 for p, l in zip(preds, labels) if p == 0 and l == 1)
714
- total = max(len(labels), 1)
715
- acc = sum(1 for p, l in zip(preds, labels) if p == l) / total
716
- prec = tp / (tp + fp) if tp + fp > 0 else 0.0
717
- rec = tp / (tp + fn) if tp + fn > 0 else 0.0
718
- f1 = (2 * prec * rec / (prec + rec)) if (prec + rec) > 0 else 0.0
719
- return float(acc), float(prec), float(rec), float(f1)
720
-
721
- @staticmethod
722
- def _roc_auc(labels: list[float], scores: list[float]) -> float:
723
- """
724
- Compute ROC AUC using the Mann-Whitney U statistic.
725
-
726
- arguments:
727
- labels:
728
- list of true binary labels (0.0 or 1.0).
729
- scores:
730
- list of predicted scores or probabilities.
731
-
732
- returns:
733
- ROC AUC value.
734
- """
735
- if len(scores) < 2 or len(set(labels)) < 2: return 0.0
736
- pairs = sorted(zip(scores, labels), key=lambda x: x[0])
737
- pos = sum(1 for _, y in pairs if y == 1); neg = sum(1 for _, y in pairs if y == 0)
738
- if pos == 0 or neg == 0: return 0.0
739
- rank_sum = 0.0; i = 0
740
- while i < len(pairs):
741
- j = i
742
- while j + 1 < len(pairs) and pairs[j + 1][0] == pairs[i][0]: j += 1
743
- avg_rank = (i + j + 2) / 2.0
744
- rank_sum += avg_rank * sum(1 for k in range(i, j + 1) if pairs[k][1] == 1)
745
- i = j + 1
746
- U = rank_sum - pos * (pos + 1) / 2.0
747
- return float(U / (pos * neg))