wisent 0.5.11__py3-none-any.whl → 0.5.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of wisent might be problematic. Click here for more details.
- wisent/__init__.py +1 -1
- wisent/core/activations/__init__.py +26 -0
- wisent/core/activations/activations.py +96 -0
- wisent/core/activations/activations_collector.py +71 -20
- wisent/core/activations/prompt_construction_strategy.py +47 -0
- wisent/core/agent/budget.py +2 -2
- wisent/core/agent/device_benchmarks.py +1 -1
- wisent/core/agent/diagnose/classifier_marketplace.py +8 -8
- wisent/core/agent/diagnose/response_diagnostics.py +4 -4
- wisent/core/agent/diagnose/synthetic_classifier_option.py +1 -1
- wisent/core/agent/diagnose/tasks/task_manager.py +3 -3
- wisent/core/agent/diagnose.py +2 -1
- wisent/core/autonomous_agent.py +10 -2
- wisent/core/benchmark_extractors.py +293 -0
- wisent/core/bigcode_integration.py +20 -7
- wisent/core/branding.py +108 -0
- wisent/core/cli/__init__.py +15 -0
- wisent/core/cli/create_steering_vector.py +138 -0
- wisent/core/cli/evaluate_responses.py +715 -0
- wisent/core/cli/generate_pairs.py +128 -0
- wisent/core/cli/generate_pairs_from_task.py +119 -0
- wisent/core/cli/generate_responses.py +129 -0
- wisent/core/cli/generate_vector_from_synthetic.py +149 -0
- wisent/core/cli/generate_vector_from_task.py +147 -0
- wisent/core/cli/get_activations.py +191 -0
- wisent/core/cli/optimize_classification.py +339 -0
- wisent/core/cli/optimize_steering.py +364 -0
- wisent/core/cli/tasks.py +182 -0
- wisent/core/cli_logger.py +22 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +27 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +49 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +118 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +146 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +129 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gsm8k.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/livecodebench.py +367 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +121 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +121 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +110 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/squad2.py +124 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sst2.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +127 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_gen.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +117 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +117 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +127 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +1 -1
- wisent/core/data_loaders/__init__.py +235 -0
- wisent/core/data_loaders/loaders/lm_loader.py +2 -2
- wisent/core/data_loaders/loaders/task_interface_loader.py +300 -0
- wisent/{cli/data_loaders/data_loader_rotator.py → core/data_loaders/rotator.py} +1 -1
- wisent/core/download_full_benchmarks.py +79 -2
- wisent/core/evaluators/benchmark_specific/__init__.py +26 -0
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/evaluator.py +17 -17
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/cpp_sanitizer.py +2 -2
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/java_sanitizer.py +2 -2
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/python_sanitizer.py +2 -2
- wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/__init__.py +3 -0
- wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/provider.py +305 -0
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/core/runtime.py +36 -4
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/entrypoint.py +2 -4
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/recipes.py +1 -1
- wisent/core/evaluators/benchmark_specific/coding/solution_generator.py +258 -0
- wisent/core/evaluators/benchmark_specific/exact_match_evaluator.py +79 -0
- wisent/core/evaluators/benchmark_specific/f1_evaluator.py +101 -0
- wisent/core/evaluators/benchmark_specific/generation_evaluator.py +197 -0
- wisent/core/{log_likelihoods_evaluator.py → evaluators/benchmark_specific/log_likelihoods_evaluator.py} +10 -2
- wisent/core/evaluators/benchmark_specific/perplexity_evaluator.py +140 -0
- wisent/core/evaluators/benchmark_specific/personalization_evaluator.py +250 -0
- wisent/{cli/evaluators/evaluator_rotator.py → core/evaluators/rotator.py} +4 -4
- wisent/core/lm_eval_harness_ground_truth.py +3 -2
- wisent/core/main.py +57 -0
- wisent/core/model_persistence.py +2 -2
- wisent/core/models/wisent_model.py +8 -6
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
- wisent/core/optuna/steering/steering_optimization.py +1 -1
- wisent/core/parser_arguments/__init__.py +10 -0
- wisent/core/parser_arguments/agent_parser.py +110 -0
- wisent/core/parser_arguments/configure_model_parser.py +7 -0
- wisent/core/parser_arguments/create_steering_vector_parser.py +59 -0
- wisent/core/parser_arguments/evaluate_parser.py +40 -0
- wisent/core/parser_arguments/evaluate_responses_parser.py +10 -0
- wisent/core/parser_arguments/full_optimize_parser.py +115 -0
- wisent/core/parser_arguments/generate_pairs_from_task_parser.py +33 -0
- wisent/core/parser_arguments/generate_pairs_parser.py +29 -0
- wisent/core/parser_arguments/generate_responses_parser.py +15 -0
- wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +127 -0
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +127 -0
- wisent/core/parser_arguments/generate_vector_parser.py +90 -0
- wisent/core/parser_arguments/get_activations_parser.py +90 -0
- wisent/core/parser_arguments/main_parser.py +152 -0
- wisent/core/parser_arguments/model_config_parser.py +59 -0
- wisent/core/parser_arguments/monitor_parser.py +17 -0
- wisent/core/parser_arguments/multi_steer_parser.py +47 -0
- wisent/core/parser_arguments/optimize_classification_parser.py +67 -0
- wisent/core/parser_arguments/optimize_sample_size_parser.py +58 -0
- wisent/core/parser_arguments/optimize_steering_parser.py +147 -0
- wisent/core/parser_arguments/synthetic_parser.py +93 -0
- wisent/core/parser_arguments/tasks_parser.py +584 -0
- wisent/core/parser_arguments/test_nonsense_parser.py +26 -0
- wisent/core/parser_arguments/utils.py +111 -0
- wisent/core/prompts/core/prompt_formater.py +3 -3
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +2 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +2 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +2 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +2 -0
- wisent/{cli/steering_methods/steering_rotator.py → core/steering_methods/rotator.py} +4 -4
- wisent/core/steering_optimizer.py +45 -21
- wisent/{synthetic → core/synthetic}/cleaners/deduper_cleaner.py +3 -3
- wisent/{synthetic → core/synthetic}/cleaners/methods/base_dedupers.py +2 -2
- wisent/{synthetic → core/synthetic}/cleaners/methods/base_refusalers.py +1 -1
- wisent/{synthetic → core/synthetic}/cleaners/pairs_cleaner.py +5 -5
- wisent/{synthetic → core/synthetic}/cleaners/refusaler_cleaner.py +4 -4
- wisent/{synthetic → core/synthetic}/db_instructions/mini_dp.py +1 -1
- wisent/{synthetic → core/synthetic}/generators/diversities/methods/fast_diversity.py +1 -1
- wisent/{synthetic → core/synthetic}/generators/pairs_generator.py +38 -12
- wisent/core/tasks/livecodebench_task.py +4 -103
- wisent/core/timing_calibration.py +1 -1
- {wisent-0.5.11.dist-info → wisent-0.5.13.dist-info}/METADATA +3 -3
- wisent-0.5.13.dist-info/RECORD +294 -0
- wisent-0.5.13.dist-info/entry_points.txt +2 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +0 -53
- wisent/classifiers/core/atoms.py +0 -747
- wisent/classifiers/models/logistic.py +0 -29
- wisent/classifiers/models/mlp.py +0 -47
- wisent/cli/classifiers/classifier_rotator.py +0 -137
- wisent/cli/cli_logger.py +0 -142
- wisent/cli/wisent_cli/commands/help_cmd.py +0 -52
- wisent/cli/wisent_cli/commands/listing.py +0 -154
- wisent/cli/wisent_cli/commands/train_cmd.py +0 -322
- wisent/cli/wisent_cli/main.py +0 -93
- wisent/cli/wisent_cli/shell.py +0 -80
- wisent/cli/wisent_cli/ui.py +0 -69
- wisent/cli/wisent_cli/util/aggregations.py +0 -43
- wisent/cli/wisent_cli/util/parsing.py +0 -126
- wisent/cli/wisent_cli/version.py +0 -4
- wisent/opti/methods/__init__.py +0 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent-0.5.11.dist-info/RECORD +0 -220
- /wisent/{benchmarks → core/evaluators/benchmark_specific/coding}/__init__.py +0 -0
- /wisent/{benchmarks/coding → core/evaluators/benchmark_specific/coding/metrics}/__init__.py +0 -0
- /wisent/{benchmarks/coding/metrics → core/evaluators/benchmark_specific/coding/metrics/core}/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/core/atoms.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/passk.py +0 -0
- /wisent/{benchmarks/coding/metrics/core → core/evaluators/benchmark_specific/coding/output_sanitizer}/__init__.py +0 -0
- /wisent/{benchmarks/coding/output_sanitizer → core/evaluators/benchmark_specific/coding/output_sanitizer/core}/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/core/atoms.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/utils.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/providers/__init__.py +0 -0
- /wisent/{benchmarks/coding/output_sanitizer → core/evaluators/benchmark_specific/coding/providers}/core/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/providers/core/atoms.py +0 -0
- /wisent/{benchmarks/coding/providers/core → core/evaluators/benchmark_specific/coding/safe_docker}/__init__.py +0 -0
- /wisent/{benchmarks/coding/providers/livecodebench → core/evaluators/benchmark_specific/coding/safe_docker/core}/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/core/atoms.py +0 -0
- /wisent/{benchmarks/coding/safe_docker → core/opti}/__init__.py +0 -0
- /wisent/{benchmarks/coding/safe_docker → core/opti}/core/__init__.py +0 -0
- /wisent/{opti → core/opti}/core/atoms.py +0 -0
- /wisent/{classifiers → core/opti/methods}/__init__.py +0 -0
- /wisent/{opti → core/opti}/methods/opti_classificator.py +0 -0
- /wisent/{opti → core/opti}/methods/opti_steering.py +0 -0
- /wisent/{classifiers/core → core/synthetic}/__init__.py +0 -0
- /wisent/{classifiers/models → core/synthetic/cleaners}/__init__.py +0 -0
- /wisent/{cli → core/synthetic/cleaners/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/cleaners/core/atoms.py +0 -0
- /wisent/{cli/classifiers → core/synthetic/cleaners/methods}/__init__.py +0 -0
- /wisent/{cli/data_loaders → core/synthetic/cleaners/methods/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/cleaners/methods/core/atoms.py +0 -0
- /wisent/{cli/evaluators → core/synthetic/db_instructions}/__init__.py +0 -0
- /wisent/{cli/steering_methods → core/synthetic/db_instructions/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/db_instructions/core/atoms.py +0 -0
- /wisent/{cli/wisent_cli → core/synthetic/generators}/__init__.py +0 -0
- /wisent/{cli/wisent_cli/commands → core/synthetic/generators/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/generators/core/atoms.py +0 -0
- /wisent/{cli/wisent_cli/util → core/synthetic/generators/diversities}/__init__.py +0 -0
- /wisent/{opti → core/synthetic/generators/diversities/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/generators/diversities/core/core.py +0 -0
- /wisent/{opti/core → core/synthetic/generators/diversities/methods}/__init__.py +0 -0
- {wisent-0.5.11.dist-info → wisent-0.5.13.dist-info}/WHEEL +0 -0
- {wisent-0.5.11.dist-info → wisent-0.5.13.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.5.11.dist-info → wisent-0.5.13.dist-info}/top_level.txt +0 -0
wisent/classifiers/core/atoms.py
DELETED
|
@@ -1,747 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
import logging
|
|
4
|
-
import os
|
|
5
|
-
from abc import ABC, abstractmethod
|
|
6
|
-
from dataclasses import dataclass, asdict
|
|
7
|
-
from typing import Any, Callable
|
|
8
|
-
|
|
9
|
-
import torch
|
|
10
|
-
import torch.nn as nn
|
|
11
|
-
import torch.optim as optim
|
|
12
|
-
from torch.utils.data import DataLoader, TensorDataset, random_split
|
|
13
|
-
import numpy as np
|
|
14
|
-
|
|
15
|
-
from torch.nn.modules.loss import _Loss
|
|
16
|
-
|
|
17
|
-
# Names exported by `from <module> import *`.
__all__ = [
    "ClassifierTrainConfig",
    "ClassifierMetrics",
    "ClassifierTrainReport",
    "ClassifierError",
    "BaseClassifier",
]

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
@dataclass(slots=True, frozen=True)
|
|
29
|
-
class ClassifierTrainConfig:
|
|
30
|
-
"""
|
|
31
|
-
Training configuration for classifiers.
|
|
32
|
-
|
|
33
|
-
attributes:
|
|
34
|
-
test_size:
|
|
35
|
-
fraction of data to hold out for testing
|
|
36
|
-
num_epochs:
|
|
37
|
-
maximum number of training epochs
|
|
38
|
-
batch_size:
|
|
39
|
-
training batch size
|
|
40
|
-
learning_rate:
|
|
41
|
-
optimizer learning rate
|
|
42
|
-
monitor:
|
|
43
|
-
which metric to monitor for best epoch selection
|
|
44
|
-
random_state:
|
|
45
|
-
random seed for data shuffling and initialization
|
|
46
|
-
"""
|
|
47
|
-
test_size: float = 0.2
|
|
48
|
-
num_epochs: int = 50
|
|
49
|
-
batch_size: int = 32
|
|
50
|
-
learning_rate: float = 1e-3
|
|
51
|
-
monitor: str = "accuracy"
|
|
52
|
-
random_state: int = 42
|
|
53
|
-
|
|
54
|
-
@dataclass(slots=True, frozen=True)
|
|
55
|
-
class ClassifierMetrics:
|
|
56
|
-
"""
|
|
57
|
-
Evaluation metrics for classifiers.
|
|
58
|
-
|
|
59
|
-
attributes:
|
|
60
|
-
accuracy: float
|
|
61
|
-
Overall accuracy of predictions.
|
|
62
|
-
precision: float
|
|
63
|
-
Precision (positive predictive value).
|
|
64
|
-
recall: float
|
|
65
|
-
Recall (sensitivity).
|
|
66
|
-
f1: float
|
|
67
|
-
F1 score (harmonic mean of precision and recall).
|
|
68
|
-
auc: float
|
|
69
|
-
Area under the ROC curve.
|
|
70
|
-
"""
|
|
71
|
-
accuracy: float
|
|
72
|
-
precision: float
|
|
73
|
-
recall: float
|
|
74
|
-
f1: float
|
|
75
|
-
auc: float
|
|
76
|
-
|
|
77
|
-
@dataclass(slots=True, frozen=True)
|
|
78
|
-
class ClassifierTrainReport:
|
|
79
|
-
"""
|
|
80
|
-
Training report for classifiers.
|
|
81
|
-
|
|
82
|
-
attributes:
|
|
83
|
-
classifier_name: str
|
|
84
|
-
Name of the classifier.
|
|
85
|
-
input_dim: int
|
|
86
|
-
Dimensionality of the input features.
|
|
87
|
-
best_epoch: int
|
|
88
|
-
Epoch number of the best model.
|
|
89
|
-
epochs_ran: int
|
|
90
|
-
Total number of epochs run.
|
|
91
|
-
final: ClassifierMetrics
|
|
92
|
-
Final evaluation metrics on the test set. It contains accuracy, precision, recall, f1, and auc.
|
|
93
|
-
history: dict[str, list[float]]
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
"""
|
|
97
|
-
classifier_name: str
|
|
98
|
-
input_dim: int
|
|
99
|
-
best_epoch: int
|
|
100
|
-
epochs_ran: int
|
|
101
|
-
final: ClassifierMetrics
|
|
102
|
-
history: dict[str, list[float]]
|
|
103
|
-
|
|
104
|
-
def asdict(self) -> dict[str, str | int | float | dict]:
|
|
105
|
-
"""
|
|
106
|
-
Return a dictionary representation of the report.
|
|
107
|
-
|
|
108
|
-
returns:
|
|
109
|
-
A dictionary with all report fields, including nested metrics.
|
|
110
|
-
|
|
111
|
-
example:
|
|
112
|
-
>>> report.asdict()
|
|
113
|
-
{
|
|
114
|
-
"classifier_name": "mlp",
|
|
115
|
-
"input_dim": 4,
|
|
116
|
-
"best_epoch": 23,
|
|
117
|
-
"epochs_ran": 30,
|
|
118
|
-
"final": {
|
|
119
|
-
"accuracy": 0.95,
|
|
120
|
-
"precision": 0.96,
|
|
121
|
-
"recall": 0.94,
|
|
122
|
-
"f1": 0.95,
|
|
123
|
-
"auc": 0.98
|
|
124
|
-
},
|
|
125
|
-
"history": {
|
|
126
|
-
"train_loss": [...],
|
|
127
|
-
"test_loss": [...],
|
|
128
|
-
"accuracy": [...],
|
|
129
|
-
"precision": [...],
|
|
130
|
-
"recall": [...],
|
|
131
|
-
"f1": [...],
|
|
132
|
-
"auc": [...]
|
|
133
|
-
}
|
|
134
|
-
}
|
|
135
|
-
"""
|
|
136
|
-
d = asdict(self); d["final"] = asdict(self.final); return d
|
|
137
|
-
|
|
138
|
-
class ClassifierError(RuntimeError):
    """Error raised by classifiers, e.g. when X and y differ in length during fitting."""
|
|
140
|
-
|
|
141
|
-
class BaseClassifier(ABC):
|
|
142
|
-
name: str = "base"
|
|
143
|
-
description: str = "Abstract classifier"
|
|
144
|
-
|
|
145
|
-
_REGISTRY: dict[str, type[BaseClassifier]] = {}
|
|
146
|
-
|
|
147
|
-
model: nn.Module | None
|
|
148
|
-
device: str
|
|
149
|
-
dtype: torch.dtype
|
|
150
|
-
threshold: float
|
|
151
|
-
|
|
152
|
-
def __init_subclass__(cls, **kwargs) -> None:
    """
    Register every subclass in `BaseClassifier._REGISTRY` under its `name`.

    raises:
        TypeError:
            if the subclass has no truthy `name` attribute.
        ValueError:
            if another class is already registered under the same name.
    """
    super().__init_subclass__(**kwargs)
    # NOTE(review): this hook is only invoked for subclasses, so this guard
    # looks purely defensive.
    if cls is BaseClassifier:
        return
    # NOTE(review): `name` is inherited as "base" from BaseClassifier, so a
    # subclass that does not override it passes this check and is registered
    # as "base" -- confirm whether an explicit override should be required.
    if not getattr(cls, "name", None):
        raise TypeError("Classifier subclasses must define class attribute `name`.")
    if cls.name in BaseClassifier._REGISTRY:
        raise ValueError(f"Duplicate classifier name: {cls.name!r}")
    BaseClassifier._REGISTRY[cls.name] = cls
|
|
161
|
-
|
|
162
|
-
def __init__(
    self,
    threshold: float = 0.5,
    device: str | None = None,
    dtype: torch.dtype = torch.float32,
) -> None:
    """
    Store the decision threshold, device and dtype; the model starts unset.

    arguments:
        threshold:
            probability cut-off for the positive class, must lie in [0.0, 1.0].
        device:
            device name; when falsy, "cuda" is used if available, else "cpu".
        dtype:
            tensor dtype; float32 is always used when the device is "mps".

    raises:
        ValueError:
            if threshold is outside [0.0, 1.0].
    """
    if not 0.0 <= threshold <= 1.0:
        raise ValueError("threshold must be in [0.0, 1.0]")
    self.threshold = threshold
    self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    # On "mps" the requested dtype is ignored and float32 is forced.
    self.dtype = torch.float32 if self.device == "mps" else dtype
    # No model yet; fit() builds one when this is still None.
    self.model = None
|
|
174
|
-
|
|
175
|
-
@abstractmethod
def build_model(self, input_dim: int, **model_params: Any) -> nn.Module:
    """Return a torch.nn.Module that outputs P(y=1) ∈ [0,1]."""
    # Abstract hook: concrete classifiers must override this.
    raise NotImplementedError
|
|
179
|
-
|
|
180
|
-
def model_hyperparams(self) -> dict[str, Any]:
    """Return model hyperparameters; the base implementation returns an empty dict."""
    # presumably overridden by subclasses that have hyperparameters -- TODO confirm
    return {}
|
|
182
|
-
|
|
183
|
-
def fit(
    self,
    X,
    y,
    config: ClassifierTrainConfig | None = None,
    optimizer: str | optim.Optimizer | Callable[..., Any] | None = None,
    lr: float | None = None,
    optimizer_kwargs: dict | None = None,
    criterion: nn.Module | str | None = None,
    on_epoch_end: Callable[[int, dict[str, float]], bool | None] | None = None,
    **model_params: Any,
) -> ClassifierTrainReport:
    """
    Train the classifier on (X, y) and return a training report.

    The data is split into train/test loaders by `_make_dataloaders`. After
    every epoch the test loader is scored; the weights of the epoch with the
    highest `cfg.monitor` value are kept and restored before the final
    evaluation.

    arguments:
        X:
            features; converted with `to_2d_tensor`.
        y:
            labels; converted with `to_1d_tensor`.
        config:
            training configuration; defaults to `ClassifierTrainConfig()`.
        optimizer:
            optimizer specification, forwarded to `_make_optimizer`.
        lr:
            learning rate; overrides `config.learning_rate` when not None.
        optimizer_kwargs:
            extra keyword arguments, forwarded to `_make_optimizer`.
        criterion:
            loss specification, forwarded to `_make_criterion`; when None,
            `configure_criterion()` is used instead.
        on_epoch_end:
            optional callback called as `on_epoch_end(epoch, metrics)` with the
            zero-based epoch index and the latest value of every history
            entry; a truthy return value stops training.
        **model_params:
            forwarded to `build_model` when no model has been built yet.

    returns:
        ClassifierTrainReport with the final test metrics and the per-epoch
        history. `best_epoch` is the 1-based epoch whose weights were kept,
        or 0 when no epoch produced a checkpoint.

    raises:
        ClassifierError:
            if X and y contain a different number of samples.
    """
    # 1 configuration and seeding
    cfg = config or ClassifierTrainConfig()
    torch.manual_seed(cfg.random_state)

    # 2 creating tensors
    X_tensor = self.to_2d_tensor(X, device=self.device, dtype=self.dtype)
    y_tensor = self.to_1d_tensor(y, device=self.device, dtype=self.dtype)

    # 3 checking dimensions
    if X_tensor.shape[0] != y_tensor.shape[0]:
        raise ClassifierError(f"X and y length mismatch: {X_tensor.shape[0]} vs {y_tensor.shape[0]}")

    # Computed unconditionally: the report below needs it even when the model
    # already exists (previously this raised UnboundLocalError on a re-fit).
    input_dim = int(X_tensor.shape[1])
    if self.model is None:
        self.model = self.build_model(input_dim, **model_params).to(self.device)

    # 4 creating dataloaders
    train_loader, test_loader = self._make_dataloaders(X_tensor, y_tensor, cfg)

    # 5 creating criterion and optimizer
    crit = self._make_criterion(criterion) if criterion is not None else self.configure_criterion()
    learn_rate = lr if lr is not None else cfg.learning_rate
    opt = self._make_optimizer(self.model, optimizer, learn_rate, optimizer_kwargs or {})

    # 6 best-checkpoint bookkeeping
    best_metric = float("-inf")
    best_state: dict[str, torch.Tensor] | None = None
    # Tracked together with best_state so the reported epoch always matches
    # the restored weights, and stays defined when zero epochs are run.
    best_epoch = 0

    # 7 history
    history: dict[str, list[float]] = {
        "train_loss": [], "test_loss": [],
        "accuracy": [], "precision": [], "recall": [], "f1": [], "auc": [],
    }

    # 8 main loop
    for epoch in range(cfg.num_epochs):
        # one epoch
        train_loss = self._train_one_epoch(self.model, train_loader, opt, crit)
        test_loss, probs, labels = self._eval_one_epoch(self.model, test_loader, crit)

        preds = [1.0 if p >= self.threshold else 0.0 for p in probs]
        acc, prec, rec, f1 = self._basic_prf(preds, labels)
        auc = self._roc_auc(labels, probs)

        history["train_loss"].append(train_loss)
        history["test_loss"].append(test_loss)
        history["accuracy"].append(acc)
        history["precision"].append(prec)
        history["recall"].append(rec)
        history["f1"].append(f1)
        history["auc"].append(auc)

        # keep best checkpoint by cfg.monitor (strict '>' keeps the earliest best)
        monitored = history[cfg.monitor][-1]
        if monitored > best_metric:
            best_metric = monitored
            best_epoch = epoch + 1
            best_state = {k: v.detach().clone() for k, v in self.model.state_dict().items()}

        # optional external observer/pruner
        if on_epoch_end is not None:
            stop = on_epoch_end(epoch, {k: history[k][-1] for k in history})
            if stop:
                break

        if (epoch == 0) or ((epoch + 1) % 10 == 0) or (epoch == cfg.num_epochs - 1):
            logger.info("[%s] epoch %d/%d train=%.4f test=%.4f acc=%.4f f1=%.4f",
                        self.name, epoch + 1, cfg.num_epochs, train_loss, test_loss, acc, f1)

    if best_state is not None:
        self.model.load_state_dict(best_state)

    # final pass with the restored weights
    test_loss, probs, labels = self._eval_one_epoch(self.model, test_loader, crit)
    preds = [1.0 if p >= self.threshold else 0.0 for p in probs]
    acc, prec, rec, f1 = self._basic_prf(preds, labels)
    auc = self._roc_auc(labels, probs)
    final = ClassifierMetrics(acc, prec, rec, f1, auc)

    return ClassifierTrainReport(
        classifier_name=self.name,
        input_dim=input_dim,
        best_epoch=best_epoch,
        epochs_ran=len(history["accuracy"]),
        final=final,
        history={k: [float(v) for v in vs] for k, vs in history.items()},
    )
|
|
284
|
-
|
|
285
|
-
def _make_dataloaders(
    self,
    X: torch.Tensor | np.ndarray,
    y: torch.Tensor | np.ndarray,
    cfg: ClassifierTrainConfig,
) -> tuple[DataLoader, DataLoader]:
    """
    Split (X, y) into train/test using a seeded random split and wrap each in DataLoaders.

    With fewer than two samples no split is possible, so both loaders wrap
    the whole dataset.

    arguments:
        X:
            2D feature array or tensor.
        y:
            1D label array or tensor.
        cfg:
            training configuration with test_size, batch_size, and random_state.

    returns:
        tuple of (train_dataloader, test_dataloader)

    example:
        >>> X = np.random.randn(100, 2).astype(np.float32)
        >>> print(X.shape)
        (100, 2)
        >>> y = np.random.randint(0, 2, size=(100,)).astype(np.int64)
        >>> print(y.shape)
        (100,)
        >>> cfg = ClassifierTrainConfig(test_size=0.2, batch_size=16, random_state=42)
        >>> train_loader, test_loader = self._make_dataloaders(X, y, cfg)
        >>> print(len(train_loader.dataset), len(test_loader.dataset))
        80 20
        >>> xb, yb = next(iter(train_loader))
        >>> print(xb.shape, yb.shape)
        torch.Size([16, 2]) torch.Size([16])
    """
    if isinstance(X, np.ndarray):
        X = torch.from_numpy(X)
    if isinstance(y, np.ndarray):
        y = torch.from_numpy(y)

    ds = TensorDataset(X, y)

    if len(ds) < 2:
        return (
            DataLoader(ds, batch_size=cfg.batch_size, shuffle=True),
            DataLoader(ds, batch_size=cfg.batch_size, shuffle=False),
        )

    # At least one test sample, and at least one sample left for training.
    test_count = max(1, int(round(cfg.test_size * len(ds))))
    test_count = min(test_count, len(ds) - 1)
    train_count = len(ds) - test_count

    # A dedicated seeded generator makes the split reproducible.
    gen = torch.Generator().manual_seed(cfg.random_state)
    train_ds, test_ds = random_split(ds, [train_count, test_count], generator=gen)

    return (
        DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True),
        DataLoader(test_ds, batch_size=cfg.batch_size, shuffle=False),
    )
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
def predict(self, X: torch.Tensor | np.ndarray) -> int | list[int]:
    """
    Predict class labels for the given input.

    A probability greater than or equal to `self.threshold` maps to 1,
    anything below maps to 0.

    arguments:
        X:
            2D feature array or tensor.

    returns:
        a single int when exactly one prediction is produced, otherwise a
        list of int.

    example:
        >>> X = np.random.randn(5, 2).astype(np.float32)
        >>> preds = self.predict(X)
        >>> print(preds)
        [0, 1, 1, 0, 0]
    """
    self._require_model()

    X_tensor = self.to_2d_tensor(X, device=self.device, dtype=self.dtype)

    with torch.no_grad():
        probs = self._forward_probs(self.model, X_tensor).view(-1).cpu().tolist()
    preds = [1 if p >= self.threshold else 0 for p in probs]
    # A lone prediction is unwrapped to a scalar; evaluate() relies on this.
    return preds[0] if len(preds) == 1 else preds
|
|
380
|
-
|
|
381
|
-
def predict_proba(self, X: torch.Tensor | np.ndarray) -> float | list[float]:
    """
    Predict class probabilities for the given input.

    arguments:
        X:
            2D feature array or tensor.

    returns:
        a single float when exactly one probability is produced, otherwise
        a list of float.

    example:
        >>> X = np.random.randn(5, 2).astype(np.float32)
        >>> probs = self.predict_proba(X)
        >>> print(probs)
        [0.23, 0.76, 0.54, 0.12, 0.34]
    """
    self._require_model()

    X_tensor = self.to_2d_tensor(X, device=self.device, dtype=self.dtype)

    with torch.no_grad():
        probs = self._forward_probs(self.model, X_tensor).view(-1).cpu().tolist()
    # A lone probability is unwrapped to a scalar; evaluate() relies on this.
    return probs[0] if len(probs) == 1 else probs
|
|
410
|
-
|
|
411
|
-
def evaluate(self, X: torch.Tensor | np.ndarray, y: torch.Tensor | np.ndarray) -> dict[str, float]:
|
|
412
|
-
"""
|
|
413
|
-
Evaluate the model on the given dataset and return metrics.
|
|
414
|
-
|
|
415
|
-
arguments:
|
|
416
|
-
X:
|
|
417
|
-
2D feature array or tensor.
|
|
418
|
-
y:
|
|
419
|
-
1D label array or tensor.
|
|
420
|
-
|
|
421
|
-
returns:
|
|
422
|
-
dictionary of evaluation metrics.
|
|
423
|
-
|
|
424
|
-
flow:
|
|
425
|
-
>>> X = np.random.randn(2, 2).astype(np.float32)
|
|
426
|
-
>>> y = np.random.randint(0, 2, size=(2,)).astype(np.int64)
|
|
427
|
-
>>> print(X)
|
|
428
|
-
[[ 0.123 -1.456]
|
|
429
|
-
[ 0.789 0.012]]
|
|
430
|
-
>>> print(y)
|
|
431
|
-
[1, 0]
|
|
432
|
-
>>> y_pred = self.predict(X)
|
|
433
|
-
>>> print(y_pred)
|
|
434
|
-
[0, 0]
|
|
435
|
-
>>> y_prob = self.predict_proba(X)
|
|
436
|
-
>>> print(y_prob)
|
|
437
|
-
[0.34, 0.12]
|
|
438
|
-
>>> metrics = self.evaluate(X, y)
|
|
439
|
-
>>> print(metrics)
|
|
440
|
-
{'accuracy': 0.5, ...}
|
|
441
|
-
"""
|
|
442
|
-
y_pred = self.predict(X)
|
|
443
|
-
y_prob = self.predict_proba(X)
|
|
444
|
-
preds = [float(y_pred)] if isinstance(y_pred, int) else [float(v) for v in y_pred]
|
|
445
|
-
probs = [float(y_prob)] if isinstance(y_prob, float) else [float(v) for v in y_prob]
|
|
446
|
-
labels = y.detach().cpu().view(-1).tolist() if isinstance(y, torch.Tensor) else list(y)
|
|
447
|
-
acc, prec, rec, f1 = self._basic_prf(preds, labels)
|
|
448
|
-
auc = self._roc_auc(labels, probs)
|
|
449
|
-
return {"accuracy": acc, "precision": prec, "recall": rec, "f1": f1, "auc": auc}
|
|
450
|
-
|
|
451
|
-
def configure_criterion(self) -> nn.Module:
    """
    Provide the default training loss.

    returns:
        a new nn.BCELoss instance.
    """
    default_loss = nn.BCELoss()
    return default_loss
|
|
452
|
-
|
|
453
|
-
def _make_criterion(self, spec: nn.Module | str) -> nn.Module:
|
|
454
|
-
"""
|
|
455
|
-
Create a loss criterion from a string or module.
|
|
456
|
-
|
|
457
|
-
arguments:
|
|
458
|
-
spec:
|
|
459
|
-
loss specification, either a string or a torch.nn.Module instance.
|
|
460
|
-
|
|
461
|
-
returns:
|
|
462
|
-
a torch.nn.Module loss function.
|
|
463
|
-
|
|
464
|
-
raises:
|
|
465
|
-
ValueError:
|
|
466
|
-
if the string specification is unknown.
|
|
467
|
-
"""
|
|
468
|
-
if isinstance(spec, nn.Module): return spec
|
|
469
|
-
key = str(spec).strip().lower()
|
|
470
|
-
if key in {"bce", "bceloss"}: return nn.BCELoss()
|
|
471
|
-
if key in {"bcewithlogits", "bcewithlogitsloss"}: return nn.BCEWithLogitsLoss()
|
|
472
|
-
raise ValueError(f"Unknown criterion: {spec!r}")
|
|
473
|
-
|
|
474
|
-
def configure_optimizer(self, model: nn.Module, lr: float) -> optim.Optimizer:
    """
    Build the default optimizer for a model.

    arguments:
        model:
            the model whose parameters() are handed to the optimizer.
        lr:
            the learning rate passed to the optimizer.

    returns:
        an optim.Adam instance over the model's parameters.
    """
    trainable = model.parameters()
    return optim.Adam(trainable, lr=lr)
|
|
487
|
-
|
|
488
|
-
def _make_optimizer(self, model: nn.Module, spec: str | optim.Optimizer | None, lr: float, extra: dict) -> optim.Optimizer:
|
|
489
|
-
"""
|
|
490
|
-
Create an optimizer from a specification.
|
|
491
|
-
|
|
492
|
-
arguments:
|
|
493
|
-
model:
|
|
494
|
-
the model to optimize.
|
|
495
|
-
spec:
|
|
496
|
-
optimizer specification: string, instance, callable, or None for default.
|
|
497
|
-
lr:
|
|
498
|
-
learning rate.
|
|
499
|
-
extra:
|
|
500
|
-
extra keyword arguments for optimizer constructor.
|
|
501
|
-
|
|
502
|
-
returns:
|
|
503
|
-
an optimizer instance.
|
|
504
|
-
|
|
505
|
-
raises:
|
|
506
|
-
ValueError:
|
|
507
|
-
if the string specification is unknown.
|
|
508
|
-
TypeError:
|
|
509
|
-
if the specification type is unsupported.
|
|
510
|
-
"""
|
|
511
|
-
if isinstance(spec, optim.Optimizer): return spec
|
|
512
|
-
if spec is None: return self.configure_optimizer(model, lr)
|
|
513
|
-
if isinstance(spec, str):
|
|
514
|
-
try: cls = getattr(optim, spec)
|
|
515
|
-
except AttributeError as exc: raise ValueError(f"Unknown optimizer: {spec!r}") from exc
|
|
516
|
-
return cls(model.parameters(), lr=lr, **extra)
|
|
517
|
-
if callable(spec): return spec(model.parameters(), lr=lr, **extra)
|
|
518
|
-
raise TypeError(f"Unsupported optimizer spec: {type(spec)}")
|
|
519
|
-
|
|
520
|
-
def _train_one_epoch(self, model: nn.Module, loader: DataLoader, optimizer: optim.Optimizer, criterion: _Loss) -> float:
    """
    Train the model for one pass over the given loader.

    arguments:
        model:
            the model to train; it is switched to train mode.
        loader:
            iterable yielding (features, targets) batches.
        optimizer:
            optimizer instance stepped once per batch.
        criterion:
            loss function called on the flattened outputs and targets.

    returns:
        the mean of the per-batch losses, or 0.0 if the loader yields no
        batches.
    """
    model.train()
    running_loss = 0.0
    batch_count = 0

    xb: torch.Tensor
    yb: torch.Tensor
    for xb, yb in loader:
        optimizer.zero_grad(set_to_none=True)
        outputs = self._forward_probs(model, xb).view(-1)
        # _forward_probs aligns the inputs with the model; align the targets
        # with the outputs as well so integer or off-device labels do not
        # make the loss call fail.
        targets = yb.to(device=outputs.device, dtype=outputs.dtype).reshape(-1)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += float(loss.item())
        batch_count += 1

    # max() guards the division when the loader is empty.
    return running_loss / max(batch_count, 1)
|
|
547
|
-
|
|
548
|
-
def _eval_one_epoch(self, model: nn.Module, loader: DataLoader, criterion: _Loss) -> tuple[float, list[float], list[float]]:
    """
    Evaluate the model for one pass over the given loader.

    arguments:
        model:
            the model to evaluate; it is switched to eval mode.
        loader:
            iterable yielding (features, targets) batches.
        criterion:
            loss function called on the flattened outputs and targets.

    returns:
        a tuple (mean_loss, outputs, labels): the mean of the per-batch
        losses (0.0 if the loader yields no batches), the flattened model
        outputs of all batches, and the flattened targets of all batches
        as they were yielded by the loader.
    """
    model.eval()
    running_loss = 0.0
    batch_count = 0
    probs_all: list[float] = []
    labels_all: list[float] = []

    with torch.no_grad():
        xb: torch.Tensor
        yb: torch.Tensor
        for xb, yb in loader:
            out = self._forward_probs(model, xb)
            flat_out = out.view(-1)
            # Align targets with the outputs so integer or off-device labels
            # do not make the loss call fail.
            targets = yb.to(device=flat_out.device, dtype=flat_out.dtype).reshape(-1)
            loss = criterion(flat_out, targets)
            running_loss += float(loss.item())
            batch_count += 1
            probs_all.extend(out.detach().cpu().view(-1).tolist())
            # Collect the labels as provided, not the cast copy.
            labels_all.extend(yb.detach().cpu().reshape(-1).tolist())

    # max() guards the division when the loader is empty.
    return (running_loss / max(batch_count, 1), probs_all, labels_all)
|
|
573
|
-
|
|
574
|
-
def _forward_probs(self, model: nn.Module, xb: torch.Tensor) -> torch.Tensor:
    """
    Run the model on a batch after aligning the batch with this instance.

    arguments:
        model:
            the model to call.
        xb:
            input feature tensor; it is moved to self.device when its
            device type differs from self.device and cast to self.dtype
            when its dtype differs.

    returns:
        the model output; a 1D output is reshaped to a single column
        (-1, 1), any other output is returned as produced.
    """
    batch = xb
    device_differs = batch.device.type != self.device
    if device_differs:
        batch = batch.to(self.device)
    dtype_differs = batch.dtype != self.dtype
    if dtype_differs:
        batch = batch.to(self.dtype)

    result = model(batch)
    if result.ndim == 1:
        return result.view(-1, 1)
    return result
|
|
591
|
-
|
|
592
|
-
def save_model(self, path: str) -> None:
    """
    Write the model weights and metadata to a checkpoint file.

    The checkpoint is a dict with the keys "classifier_name",
    "state_dict", "input_dim", "threshold" and "model_hyperparams".
    Missing parent directories of `path` are created.

    arguments:
        path:
            the file path to save the model to.

    raises:
        whatever _require_model() raises when no model is available.
    """
    self._require_model()

    parent_dir = os.path.dirname(os.path.abspath(path)) or "."
    os.makedirs(parent_dir, exist_ok=True)

    # The input width is read from dimension 1 of the first parameter.
    # NOTE(review): presumably that parameter is always a 2D weight matrix;
    # confirm for models whose first parameter is 1D.
    first_param = next(self.model.parameters())
    input_dim = int(first_param.shape[1])

    checkpoint = {
        "classifier_name": self.name,
        "state_dict": self.model.state_dict(),
        "input_dim": input_dim,
        "threshold": self.threshold,
        "model_hyperparams": self.model_hyperparams(),
    }
    torch.save(checkpoint, path)
    logger.info("Saved %s to %s", self.name, path)
|
|
614
|
-
|
|
615
|
-
def load_model(self, path: str) -> None:
    """
    Restore the model and decision threshold from a checkpoint file.

    The checkpoint must be a dict containing "state_dict" and
    "input_dim". "threshold" and "model_hyperparams" are optional; the
    current threshold is kept when the checkpoint has none. The model is
    rebuilt with build_model(), moved to self.device, loaded with the
    stored weights and left in eval mode.

    arguments:
        path:
            the file path to load the model from.

    raises:
        FileNotFoundError:
            if the model file does not exist.
        ClassifierError:
            if the checkpoint format is unsupported.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(path)

    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects, so only checkpoints from trusted sources should be loaded.
    checkpoint = torch.load(path, map_location=self.device, weights_only=False)

    has_required_keys = (
        isinstance(checkpoint, dict)
        and "state_dict" in checkpoint
        and "input_dim" in checkpoint
    )
    if not has_required_keys:
        raise ClassifierError("Unsupported checkpoint format.")

    self.threshold = float(checkpoint.get("threshold", self.threshold))
    feature_count = int(checkpoint["input_dim"])
    hyperparams = dict(checkpoint.get("model_hyperparams", {}))

    self.model = self.build_model(feature_count, **hyperparams).to(self.device)
    self.model.load_state_dict(checkpoint["state_dict"])
    self.model.eval()
|
|
638
|
-
|
|
639
|
-
def _require_model(self) -> None:
    """
    Ensure a model is available before it is used.

    raises:
        ClassifierError:
            if self.model is None.
    """
    if self.model is not None:
        return
    raise ClassifierError("Model not initialized. Call fit() or load_model() first.")
|
|
642
|
-
|
|
643
|
-
@classmethod
def to_2d_tensor(cls, X, device: str, dtype: torch.dtype) -> torch.Tensor:
    """
    Convert features to a 2D tensor on the given device and dtype.

    arguments:
        X:
            a tensor, or any data accepted by torch.tensor().
        device:
            target device string.
        dtype:
            target torch dtype.

    returns:
        a 2D tensor; 1D input is treated as a single row of shape (1, n).

    raises:
        ClassifierError:
            if the converted data is neither 1D nor 2D.
    """
    if isinstance(X, torch.Tensor):
        converted = X.to(device=device, dtype=dtype)
    else:
        converted = torch.tensor(X, device=device, dtype=dtype)

    # A flat vector is interpreted as one sample.
    if converted.ndim == 1:
        converted = converted.view(1, -1)
    if converted.ndim != 2:
        raise ClassifierError(f"Expected 2D features, got {tuple(converted.shape)}")
    return converted
|
|
672
|
-
|
|
673
|
-
@staticmethod
def to_1d_tensor(y, *, device: str, dtype: torch.dtype) -> torch.Tensor:
    """
    Convert labels to a flat 1D tensor on the given device and dtype.

    arguments:
        y:
            a tensor of any shape, or an iterable accepted by
            torch.tensor(list(y)).
        device:
            target device string.
        dtype:
            target torch dtype.

    returns:
        1D torch tensor containing the elements of `y` in row-major order.
    """
    # reshape() rather than view(): view(-1) raises on non-contiguous
    # tensors (e.g. a transposed label matrix), reshape copies when needed.
    if isinstance(y, torch.Tensor):
        return y.to(device=device, dtype=dtype).reshape(-1)
    return torch.tensor(list(y), device=device, dtype=dtype).reshape(-1)
|
|
696
|
-
|
|
697
|
-
@staticmethod
def _basic_prf(preds: list[float], labels: list[float]) -> tuple[float, float, float, float]:
    """
    Compute accuracy, precision, recall and F1 for binary predictions.

    arguments:
        preds:
            list of predicted labels (0.0 or 1.0).
        labels:
            list of true labels (0.0 or 1.0).

    returns:
        tuple of (accuracy, precision, recall, f1). Accuracy is divided by
        len(labels) (at least 1); precision, recall and f1 are 0.0 when
        their denominator is zero.
    """
    true_pos = 0
    false_pos = 0
    false_neg = 0
    matches = 0
    for predicted, actual in zip(preds, labels):
        if predicted == actual:
            matches += 1
        if predicted == 1:
            if actual == 1:
                true_pos += 1
            elif actual == 0:
                false_pos += 1
        elif predicted == 0 and actual == 1:
            false_neg += 1

    sample_count = max(len(labels), 1)
    accuracy = matches / sample_count

    predicted_pos = true_pos + false_pos
    actual_pos = true_pos + false_neg
    precision = true_pos / predicted_pos if predicted_pos > 0 else 0.0
    recall = true_pos / actual_pos if actual_pos > 0 else 0.0

    if (precision + recall) > 0:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0
    return float(accuracy), float(precision), float(recall), float(f1)
|
|
720
|
-
|
|
721
|
-
@staticmethod
|
|
722
|
-
def _roc_auc(labels: list[float], scores: list[float]) -> float:
|
|
723
|
-
"""
|
|
724
|
-
Compute ROC AUC using the Mann-Whitney U statistic.
|
|
725
|
-
|
|
726
|
-
arguments:
|
|
727
|
-
labels:
|
|
728
|
-
list of true binary labels (0.0 or 1.0).
|
|
729
|
-
scores:
|
|
730
|
-
list of predicted scores or probabilities.
|
|
731
|
-
|
|
732
|
-
returns:
|
|
733
|
-
ROC AUC value.
|
|
734
|
-
"""
|
|
735
|
-
if len(scores) < 2 or len(set(labels)) < 2: return 0.0
|
|
736
|
-
pairs = sorted(zip(scores, labels), key=lambda x: x[0])
|
|
737
|
-
pos = sum(1 for _, y in pairs if y == 1); neg = sum(1 for _, y in pairs if y == 0)
|
|
738
|
-
if pos == 0 or neg == 0: return 0.0
|
|
739
|
-
rank_sum = 0.0; i = 0
|
|
740
|
-
while i < len(pairs):
|
|
741
|
-
j = i
|
|
742
|
-
while j + 1 < len(pairs) and pairs[j + 1][0] == pairs[i][0]: j += 1
|
|
743
|
-
avg_rank = (i + j + 2) / 2.0
|
|
744
|
-
rank_sum += avg_rank * sum(1 for k in range(i, j + 1) if pairs[k][1] == 1)
|
|
745
|
-
i = j + 1
|
|
746
|
-
U = rank_sum - pos * (pos + 1) / 2.0
|
|
747
|
-
return float(U / (pos * neg))
|