segmentae 1.5.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- segmentae/__init__.py +83 -0
- segmentae/anomaly_detection.py +20 -0
- segmentae/autoencoders/__init__.py +16 -0
- segmentae/autoencoders/batch_norm.py +208 -0
- segmentae/autoencoders/dense.py +211 -0
- segmentae/autoencoders/ensemble.py +219 -0
- segmentae/clusters/__init__.py +18 -0
- segmentae/clusters/clustering.py +171 -0
- segmentae/clusters/models.py +438 -0
- segmentae/clusters/registry.py +75 -0
- segmentae/core/__init__.py +65 -0
- segmentae/core/base.py +108 -0
- segmentae/core/constants.py +91 -0
- segmentae/core/exceptions.py +60 -0
- segmentae/core/types.py +55 -0
- segmentae/data_sources/__init__.py +3 -0
- segmentae/data_sources/examples.py +198 -0
- segmentae/metrics/__init__.py +6 -0
- segmentae/metrics/performance_metrics.py +119 -0
- segmentae/optimization/__init__.py +6 -0
- segmentae/optimization/optimizer.py +375 -0
- segmentae/pipeline/__init__.py +21 -0
- segmentae/pipeline/reconstruction.py +214 -0
- segmentae/pipeline/segmentae.py +562 -0
- segmentae/processing/__init__.py +21 -0
- segmentae/processing/preprocessing.py +263 -0
- segmentae/processing/simplifier.py +74 -0
- segmentae/utils/__init__.py +17 -0
- segmentae/utils/validation.py +94 -0
- segmentae-1.5.20.dist-info/METADATA +393 -0
- segmentae-1.5.20.dist-info/RECORD +34 -0
- segmentae-1.5.20.dist-info/WHEEL +5 -0
- segmentae-1.5.20.dist-info/licenses/LICENSE +21 -0
- segmentae-1.5.20.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
from typing import Dict
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class PhaseType(str, Enum):
    """Execution phases of the SegmentAE reconstruction workflow.

    Because the enum mixes in ``str``, every member compares equal to its
    lowercase value, e.g. ``PhaseType.TESTING == "testing"``.
    """

    EVALUATION = "evaluation"
    TESTING = "testing"
    PREDICTION = "prediction"
|
|
11
|
+
|
|
12
|
+
class ClusterModel(str, Enum):
    """Identifiers of the clustering algorithms usable for data segmentation.

    Member values are the display/lookup names (``"KMeans"``, ``"GMM"``, ...);
    being ``str`` subclasses, members compare equal to those names.
    """

    KMEANS = "KMeans"
    MINIBATCH_KMEANS = "MiniBatchKMeans"
    GMM = "GMM"
    AGGLOMERATIVE = "Agglomerative"
|
|
19
|
+
|
|
20
|
+
class ThresholdMetric(str, Enum):
    """Reconstruction-error metrics that anomaly thresholds can be based on.

    Values are the lowercase names accepted from users; as ``str`` subclasses
    the members compare equal to those names.
    """

    MSE = "mse"              # mean squared error
    MAE = "mae"              # mean absolute error
    RMSE = "rmse"            # root mean squared error
    MAX_ERROR = "max_error"  # maximum error
|
|
27
|
+
|
|
28
|
+
class EncoderType(str, Enum):
    """Supported encoders for categorical variables.

    Each value is the encoder's user-facing name; members compare equal to
    it because the enum mixes in ``str``.
    """

    IFREQUENCY = "IFrequencyEncoder"
    LABEL = "LabelEncoder"
    ONEHOT = "OneHotEncoder"
|
|
34
|
+
|
|
35
|
+
class ScalerType(str, Enum):
    """Supported scalers for normalising numerical features.

    Each value is the scaler's user-facing name; members compare equal to it
    because the enum mixes in ``str``.
    """

    MINMAX = "MinMaxScaler"
    STANDARD = "StandardScaler"
    ROBUST = "RobustScaler"
|
|
41
|
+
|
|
42
|
+
class ImputerType(str, Enum):
    """Supported strategies for imputing missing values.

    Currently only ``"Simple"`` is defined; the member compares equal to
    that string because the enum mixes in ``str``.
    """

    SIMPLE = "Simple"
|
|
46
|
+
|
|
47
|
+
# Mapping dictionaries: lookup tables keyed by the enum members of this module.

# Threshold metric -> name of the reconstruction-error column it refers to.
METRIC_COLUMN_MAP: Dict[ThresholdMetric, str] = {
    ThresholdMetric.MSE: "MSE_Recons_error",
    ThresholdMetric.MAE: "MAE_Recons_error",
    ThresholdMetric.RMSE: "RMSE_Recons_error",
    ThresholdMetric.MAX_ERROR: "Max_Recons_error",
}

# Lowercase metric name -> ThresholdMetric member. Built from the enum itself
# (in declaration order) so the two can never drift apart.
METRIC_NAME_MAP: Dict[str, ThresholdMetric] = {
    member.value: member for member in ThresholdMetric
}

# Encoder type -> name of the class implementing it (presumably resolved to
# an actual class elsewhere in the package; verify against callers).
ENCODER_CLASS_MAP: Dict[EncoderType, str] = {
    EncoderType.IFREQUENCY: "AutoIFrequencyEncoder",
    EncoderType.LABEL: "AutoLabelEncoder",
    EncoderType.ONEHOT: "AutoOneHotEncoder",
}

# Scaler type -> name of the class implementing it.
SCALER_CLASS_MAP: Dict[ScalerType, str] = {
    ScalerType.MINMAX: "AutoMinMaxScaler",
    ScalerType.STANDARD: "AutoStandardScaler",
    ScalerType.ROBUST: "AutoRobustScaler",
}

# Imputer type -> name of the class implementing it.
IMPUTER_CLASS_MAP: Dict[ImputerType, str] = {
    ImputerType.SIMPLE: "AutoSimpleImputer",
}
|
|
77
|
+
|
|
78
|
+
def get_metric_column_name(metric: ThresholdMetric) -> str:
    """Return the reconstruction-error column name for *metric*.

    Args:
        metric: A ``ThresholdMetric`` member.

    Returns:
        The column name stored in ``METRIC_COLUMN_MAP`` for that metric.

    Raises:
        KeyError: if *metric* has no entry in ``METRIC_COLUMN_MAP``.
    """
    column_name = METRIC_COLUMN_MAP[metric]
    return column_name
|
|
81
|
+
|
|
82
|
+
def parse_threshold_metric(metric_str: str) -> ThresholdMetric:
    """Translate a case-insensitive metric name into a ``ThresholdMetric``.

    Args:
        metric_str: Name such as ``"MSE"`` or ``"max_error"``; it is
            lowercased before lookup.

    Returns:
        The matching ``ThresholdMetric`` member from ``METRIC_NAME_MAP``.

    Raises:
        ValueError: if the lowercased name is not a key of
            ``METRIC_NAME_MAP``; the message lists the accepted names.
    """
    try:
        return METRIC_NAME_MAP[metric_str.lower()]
    except KeyError:
        accepted = ", ".join(METRIC_NAME_MAP)
        raise ValueError(
            f"Unknown threshold metric: '{metric_str}'. "
            f"Valid options are: {accepted}"
        ) from None
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
class SegmentAEError(Exception):
    """Root of the SegmentAE exception hierarchy.

    The message is passed to ``Exception`` (so ``str(err)`` returns it) and
    is also kept on the ``message`` attribute for programmatic access.
    """

    def __init__(self, message: str):
        super().__init__(message)
        self.message = message
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class ClusteringError(SegmentAEError):
    """Raised for clustering-related failures.

    The stored message is prefixed with ``"Clustering Error: "``.
    """

    def __init__(self, message: str):
        prefixed = f"Clustering Error: {message}"
        super().__init__(prefixed)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ReconstructionError(SegmentAEError):
    """Raised for reconstruction-related failures.

    The stored message is prefixed with ``"Reconstruction Error: "``.
    """

    def __init__(self, message: str):
        prefixed = f"Reconstruction Error: {message}"
        super().__init__(prefixed)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ValidationError(SegmentAEError):
    """Raised when input data or arguments fail validation.

    Args:
        message: What was invalid; prefixed with ``"Validation Error: "``.
        suggestion: Optional remediation hint. When truthy it is appended on
            a new line as ``"Suggestion: <hint>"``.
    """

    def __init__(self, message: str, suggestion: str = None):
        lines = [f"Validation Error: {message}"]
        if suggestion:
            lines.append(f"Suggestion: {suggestion}")
        super().__init__("\n".join(lines))
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class ModelNotFittedError(SegmentAEError):
    """Raised when a component is used before it has been fitted.

    Args:
        message: Custom message. When ``None`` (the default), a generic
            message naming *component* is generated instead.
        component: Name inserted into the generated default message.
    """

    def __init__(self, message: str = None, component: str = "Model"):
        generated = (
            f"{component} must be fitted before use. "
            "Please call the fit() or appropriate fitting method first."
        )
        super().__init__(generated if message is None else message)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class ConfigurationError(SegmentAEError):
    """Raised when a configuration parameter is invalid.

    Args:
        message: What was wrong; prefixed with ``"Configuration Error: "``.
        valid_options: Optional iterable of accepted values. When truthy,
            their ``str()`` forms are listed on an extra
            ``"Valid options: ..."`` line.
    """

    def __init__(self, message: str, valid_options: list = None):
        error_msg = f"Configuration Error: {message}"
        if valid_options:
            listed = ", ".join(map(str, valid_options))
            error_msg = f"{error_msg}\nValid options: {listed}"
        super().__init__(error_msg)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class AutoencoderError(SegmentAEError):
    """Raised for autoencoder-related failures.

    The stored message is prefixed with ``"Autoencoder Error: "``.
    """

    def __init__(self, message: str):
        prefixed = f"Autoencoder Error: {message}"
        super().__init__(prefixed)
|
segmentae/core/types.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Protocol, Union
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pandas as pd
|
|
5
|
+
|
|
6
|
+
# Type aliases for commonly used types.

# pandas containers
DataFrame = pd.DataFrame
Series = pd.Series

# numpy array type
NDArray = np.ndarray

# Mapping with string keys and values of any type.
DictStrAny = Dict[str, Any]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class AutoencoderProtocol(Protocol):
    """Structural (duck-typed) interface expected of autoencoder models.

    Any object with a compatible ``predict`` method satisfies it; explicit
    inheritance is not required.
    """

    def predict(self, X: Union[DataFrame, NDArray]) -> NDArray:
        """Return the model's reconstruction of the input *X*."""
        ...
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ClusterModelProtocol(Protocol):
    """Structural (duck-typed) interface expected of clustering models.

    A conforming object can be fitted on a DataFrame, returns cluster
    assignments from ``predict`` and exposes a read-only ``n_clusters``
    property.
    """

    def fit(self, X: DataFrame) -> None:
        """Fit the clustering model to *X*."""
        ...

    def predict(self, X: DataFrame) -> NDArray:
        """Return cluster assignments for *X*."""
        ...

    @property
    def n_clusters(self) -> int:
        """Number of clusters used by the model."""
        ...
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class PreprocessorProtocol(Protocol):
    """Structural (duck-typed) interface expected of preprocessing steps.

    A conforming object is fitted once and then used to transform
    DataFrames.
    """

    def fit(self, X: DataFrame) -> 'PreprocessorProtocol':
        """Fit on *X* and return the fitted preprocessor."""
        ...

    def transform(self, X: DataFrame) -> DataFrame:
        """Return *X* transformed by the already-fitted preprocessor."""
        ...
|
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
from typing import Optional, Tuple
|
|
2
|
+
|
|
3
|
+
import pandas as pd
|
|
4
|
+
from sklearn.model_selection import train_test_split
|
|
5
|
+
from ucimlrepo import fetch_ucirepo
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def load_dataset(
    dataset_selection: str = 'htru2_dataset',
    split_ratio: float = 0.8,
    random_state: Optional[int] = 5
) -> Tuple[pd.DataFrame, pd.DataFrame, str]:
    """
    Load and preprocess datasets for anomaly detection tasks.

    Provides access to several benchmark anomaly detection datasets including
    credit card defaults, shuttle data, and pulsar data, each fetched from the
    UCI repository through ``ucimlrepo``.

    Args:
        dataset_selection: One of ``'default_credit_card'``, ``'shuttle_148'``
            or ``'htru2_dataset'``.
        split_ratio: Fraction of the normal rows used for training (passed to
            ``train_test_split`` as ``train_size``).
        random_state: Seed for the train/test split of the normal rows.

    Returns:
        ``(train, test, target)``: ``train`` holds only normal rows, ``test``
        holds the remaining normal rows plus every anomaly, and ``target`` is
        the label column name.

    Raises:
        ValueError: if ``dataset_selection`` is not one of the names above.
    """
    # Ordered (name, loader) pairs; the first exact match wins.
    registry = (
        ("default_credit_card", _load_credit_card_dataset),
        ("shuttle_148", _load_shuttle_dataset),
        ("htru2_dataset", _load_htru2_dataset),
    )
    for name, loader in registry:
        if dataset_selection == name:
            return loader(split_ratio, random_state)

    raise ValueError(
        f"Unknown dataset: '{dataset_selection}'. "
        "Available options: 'default_credit_card', 'shuttle_148', 'htru2_dataset'"
    )
|
|
35
|
+
|
|
36
|
+
def _load_credit_card_dataset(
    split_ratio: float,
    random_state: Optional[int]
) -> Tuple[pd.DataFrame, pd.DataFrame, str]:
    """Load the UCI "default of credit card clients" dataset (id 350).

    Source: https://archive.ics.uci.edu/dataset/350/default+of+credit+card+clients

    This research aimed at the case of customers' default payments in Taiwan
    and compares the predictive accuracy of probability of default among six
    data mining methods.

    Rows with ``Y == 0`` are treated as normal and rows with ``Y == 1`` as
    anomalies. A ``split_ratio`` fraction of the normal rows forms the
    training set; the remaining normal rows plus every anomaly form the test
    set, which is shuffled with a fixed seed (42).

    Args:
        split_ratio: Fraction of normal rows used for training.
        random_state: Seed for the normal-row train/test split.

    Returns:
        ``(train, test, 'Y')``; both frames have a fresh ``RangeIndex``.
    """
    # Fetch dataset
    info = fetch_ucirepo(id=350)

    # Concatenate features and targets. Reset both indexes first so rows are
    # paired by position, as the shuttle/HTRU2 loaders do: concat(axis=1)
    # aligns on index labels and would otherwise introduce NaN rows on any
    # mismatch (which then breaks the astype(int) below).
    data = pd.concat(
        [
            info.data.features.reset_index(drop=True),
            info.data.targets.reset_index(drop=True),
        ],
        axis=1,
    )
    target = 'Y'

    # Cast target values to integer
    data[target] = data[target].astype(int)

    # Separate normal and default ("fraud") instances
    normal = data[data[target] == 0]
    fraud = data[data[target] == 1]

    # Split normal instances into training and testing sets
    train, test = train_test_split(
        normal,
        train_size=split_ratio,
        random_state=random_state
    )

    # Combine testing set with fraud instances and shuffle
    test = pd.concat([test, fraud])
    test = test.sample(frac=1, random_state=42)

    # Reset index
    train = train.reset_index(drop=True)
    test = test.reset_index(drop=True)

    # Print information
    _print_dataset_info(train, test, target, suggested_split=0.75)

    return train, test, target
|
|
80
|
+
|
|
81
|
+
def _load_shuttle_dataset(
    split_ratio: float,
    random_state: Optional[int]
) -> Tuple[pd.DataFrame, pd.DataFrame, str]:
    """Load the UCI Statlog (Shuttle) dataset (id 148) as a binary task.

    Source: https://archive.ics.uci.edu/dataset/148/statlog+shuttle

    The shuttle dataset contains 9 attributes all of which are numerical.
    Approximately 80% of the data belongs to class 1, which is relabelled
    ``0`` (normal); classes 2-7 are relabelled ``1`` (anomaly).

    A ``split_ratio`` fraction of the normal rows forms the training set; the
    remaining normal rows plus every anomaly form the test set, shuffled with
    a fixed seed (42).

    Returns:
        ``(train, test, 'class')``; both frames have a fresh ``RangeIndex``.
    """
    target = 'class'
    uci = fetch_ucirepo(id=148)

    # Pair features with labels by position rather than by index label.
    frame = pd.concat(
        [
            uci.data.features.reset_index(drop=True),
            uci.data.targets.reset_index(drop=True),
        ],
        axis=1,
    )

    # 1 = normal -> 0; every other shuttle class -> 1 (anomaly).
    class_to_binary = {1: 0, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1, 7: 1}
    frame[target] = frame[target].replace(class_to_binary).astype(int)

    normal_rows = frame[frame[target] == 0]
    anomalous_rows = frame[frame[target] == 1]

    train, held_out = train_test_split(
        normal_rows, train_size=split_ratio, random_state=random_state
    )

    test = (
        pd.concat([held_out, anomalous_rows])
        .sample(frac=1, random_state=42)
        .reset_index(drop=True)
    )
    train = train.reset_index(drop=True)

    _print_dataset_info(train, test, target, suggested_split=0.75)
    return train, test, target
|
|
128
|
+
|
|
129
|
+
def _load_htru2_dataset(
    split_ratio: float,
    random_state: Optional[int]
) -> Tuple[pd.DataFrame, pd.DataFrame, str]:
    """Load the UCI HTRU2 pulsar-candidate dataset (id 372).

    Source: https://archive.ics.uci.edu/dataset/372/htru2

    Reference:
        R. J. Lyon, B. W. Stappers, S. Cooper, J. M. Brooke, J. D. Knowles,
        Fifty Years of Pulsar Candidate Selection: From simple filters to a
        new principled real-time classification approach, Monthly Notices of
        the Royal Astronomical Society 459 (1), 1104-1123,
        DOI: 10.1093/mnras/stw656

    Rows labelled ``0`` are treated as normal and rows labelled ``1`` as
    anomalies. A ``split_ratio`` fraction of the normal rows forms the
    training set; the remaining normal rows plus every anomaly form the test
    set, shuffled with a fixed seed (42).

    Returns:
        ``(train, test, 'class')``; both frames have a fresh ``RangeIndex``.
    """
    target = 'class'
    uci = fetch_ucirepo(id=372)

    # Pair features with labels by position rather than by index label.
    frame = pd.concat(
        [
            uci.data.features.reset_index(drop=True),
            uci.data.targets.reset_index(drop=True),
        ],
        axis=1,
    )
    frame[target] = frame[target].astype(int)

    normal_rows = frame[frame[target] == 0]
    anomalous_rows = frame[frame[target] == 1]

    train, held_out = train_test_split(
        normal_rows, train_size=split_ratio, random_state=random_state
    )

    test = (
        pd.concat([held_out, anomalous_rows])
        .sample(frac=1, random_state=42)
        .reset_index(drop=True)
    )
    train = train.reset_index(drop=True)

    _print_dataset_info(train, test, target, suggested_split=0.9)
    return train, test, target
|
|
179
|
+
|
|
180
|
+
def _print_dataset_info(
|
|
181
|
+
train: pd.DataFrame,
|
|
182
|
+
test: pd.DataFrame,
|
|
183
|
+
target: str,
|
|
184
|
+
suggested_split: float
|
|
185
|
+
) -> None:
|
|
186
|
+
"""Print dataset information."""
|
|
187
|
+
info = {
|
|
188
|
+
"Train Length": len(train),
|
|
189
|
+
"Test Length": len(test),
|
|
190
|
+
"Suggested Split_Ratio": suggested_split
|
|
191
|
+
}
|
|
192
|
+
|
|
193
|
+
# Add target distribution
|
|
194
|
+
for key, value in test[target].value_counts().to_dict().items():
|
|
195
|
+
label = "Anomalies [1]" if key == 1 else "Normal [0]"
|
|
196
|
+
info[label] = value
|
|
197
|
+
|
|
198
|
+
print(info)
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
from typing import Union
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pandas as pd
|
|
5
|
+
from sklearn.metrics import (
|
|
6
|
+
accuracy_score,
|
|
7
|
+
f1_score,
|
|
8
|
+
max_error,
|
|
9
|
+
mean_absolute_error,
|
|
10
|
+
mean_squared_error,
|
|
11
|
+
precision_score,
|
|
12
|
+
r2_score,
|
|
13
|
+
recall_score,
|
|
14
|
+
root_mean_squared_error,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
from segmentae.core.exceptions import ValidationError
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def metrics_classification(
    y_true: Union[pd.Series, np.ndarray],
    y_pred: Union[pd.Series, np.ndarray]
) -> pd.DataFrame:
    """
    Calculate classification evaluation metrics.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels, same length as ``y_true``.

    Returns:
        DataFrame containing accuracy, precision, recall, and F1 score metrics
        as a single row with index ``[0]``.

    Raises:
        ValidationError: if the inputs differ in length or are empty.
    """
    _validate_classification_inputs(y_true, y_pred)

    # zero_division=0 makes undefined ratios (e.g. no predicted positives)
    # evaluate to 0.0 instead of emitting a warning.
    row = {
        'Accuracy': accuracy_score(y_true, y_pred),
        'Precision': precision_score(y_true, y_pred, zero_division=0),
        'Recall': recall_score(y_true, y_pred, zero_division=0),
        'F1 Score': f1_score(y_true, y_pred, zero_division=0),
    }
    return pd.DataFrame(row, index=[0])
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def metrics_regression(
    y_true: Union[pd.Series, np.ndarray],
    y_pred: Union[pd.Series, np.ndarray]
) -> pd.DataFrame:
    """
    Calculate regression evaluation metrics.

    Args:
        y_true: Ground-truth target values.
        y_pred: Predicted values, same length as ``y_true``.

    Returns:
        DataFrame containing MAE, MSE, RMSE, R², and Max Error metrics as a
        single row with index ``[0]``.

    Raises:
        ValidationError: if the inputs differ in length or are empty.
    """
    # Validate inputs
    _validate_regression_inputs(y_true, y_pred)

    # Calculate metrics
    mae = mean_absolute_error(y_true, y_pred)
    mse = mean_squared_error(y_true, y_pred)
    # root_mean_squared_error already returns the square root and accepts no
    # ``squared`` argument; passing squared=False raised TypeError here.
    rmse = root_mean_squared_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    maxerror = max_error(y_true, y_pred)

    # Create metrics dictionary
    metrics = {
        'Mean Absolute Error': mae,
        'Mean Squared Error': mse,
        'Root Mean Squared Error': rmse,
        'R-squared': r2,
        'Max Error': maxerror
    }

    # Convert to DataFrame
    return pd.DataFrame(metrics, index=[0])
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _validate_classification_inputs(
    y_true: Union[pd.Series, np.ndarray],
    y_pred: Union[pd.Series, np.ndarray]
) -> None:
    """Check that classification label arrays are equally long and non-empty.

    Raises:
        ValidationError: on a length mismatch (checked first), or when both
            inputs are empty.
    """
    n_true, n_pred = len(y_true), len(y_pred)

    if n_true != n_pred:
        raise ValidationError(
            f"Length mismatch: y_true has {n_true} samples, "
            f"y_pred has {n_pred} samples",
            suggestion="Ensure both arrays have the same number of samples"
        )

    if not n_true:
        raise ValidationError(
            "Empty arrays provided",
            suggestion="Provide non-empty arrays with predictions"
        )
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def _validate_regression_inputs(
    y_true: Union[pd.Series, np.ndarray],
    y_pred: Union[pd.Series, np.ndarray]
) -> None:
    """Check that regression target/prediction arrays are compatible.

    Raises:
        ValidationError: when the lengths differ (checked first), or when
            both arrays are empty.
    """
    true_len = len(y_true)
    pred_len = len(y_pred)

    if true_len != pred_len:
        raise ValidationError(
            f"Length mismatch: y_true has {true_len} samples, "
            f"y_pred has {pred_len} samples",
            suggestion="Ensure both arrays have the same number of samples"
        )

    if true_len == 0:
        raise ValidationError(
            "Empty arrays provided",
            suggestion="Provide non-empty arrays with predictions"
        )
|