segmentae 1.5.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- segmentae/__init__.py +83 -0
- segmentae/anomaly_detection.py +20 -0
- segmentae/autoencoders/__init__.py +16 -0
- segmentae/autoencoders/batch_norm.py +208 -0
- segmentae/autoencoders/dense.py +211 -0
- segmentae/autoencoders/ensemble.py +219 -0
- segmentae/clusters/__init__.py +18 -0
- segmentae/clusters/clustering.py +171 -0
- segmentae/clusters/models.py +438 -0
- segmentae/clusters/registry.py +75 -0
- segmentae/core/__init__.py +65 -0
- segmentae/core/base.py +108 -0
- segmentae/core/constants.py +91 -0
- segmentae/core/exceptions.py +60 -0
- segmentae/core/types.py +55 -0
- segmentae/data_sources/__init__.py +3 -0
- segmentae/data_sources/examples.py +198 -0
- segmentae/metrics/__init__.py +6 -0
- segmentae/metrics/performance_metrics.py +119 -0
- segmentae/optimization/__init__.py +6 -0
- segmentae/optimization/optimizer.py +375 -0
- segmentae/pipeline/__init__.py +21 -0
- segmentae/pipeline/reconstruction.py +214 -0
- segmentae/pipeline/segmentae.py +562 -0
- segmentae/processing/__init__.py +21 -0
- segmentae/processing/preprocessing.py +263 -0
- segmentae/processing/simplifier.py +74 -0
- segmentae/utils/__init__.py +17 -0
- segmentae/utils/validation.py +94 -0
- segmentae-1.5.20.dist-info/METADATA +393 -0
- segmentae-1.5.20.dist-info/RECORD +34 -0
- segmentae-1.5.20.dist-info/WHEEL +5 -0
- segmentae-1.5.20.dist-info/licenses/LICENSE +21 -0
- segmentae-1.5.20.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,562 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
from typing import Any, Dict, Optional, Union
|
|
3
|
+
|
|
4
|
+
import pandas as pd
|
|
5
|
+
from sklearn.metrics import confusion_matrix
|
|
6
|
+
|
|
7
|
+
from segmentae.clusters.clustering import Clustering
|
|
8
|
+
from segmentae.core.constants import (
|
|
9
|
+
PhaseType,
|
|
10
|
+
ThresholdMetric,
|
|
11
|
+
get_metric_column_name,
|
|
12
|
+
)
|
|
13
|
+
from segmentae.core.exceptions import (
|
|
14
|
+
AutoencoderError,
|
|
15
|
+
ConfigurationError,
|
|
16
|
+
ModelNotFittedError,
|
|
17
|
+
ValidationError,
|
|
18
|
+
)
|
|
19
|
+
from segmentae.metrics.performance_metrics import metrics_classification
|
|
20
|
+
from segmentae.pipeline.reconstruction import (
|
|
21
|
+
EvaluationConfig,
|
|
22
|
+
ReconstructionConfig,
|
|
23
|
+
compute_column_metrics,
|
|
24
|
+
compute_reconstruction_errors,
|
|
25
|
+
compute_threshold,
|
|
26
|
+
create_metrics_dataframe,
|
|
27
|
+
detect_anomalies,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
# NOTE(review): Warning is the base class of every warning category, so this
# filter silences all warnings process-wide as soon as this module is imported.
warnings.filterwarnings("ignore", category=Warning)
|
|
31
|
+
|
|
32
|
+
class SegmentAE:
    """
    SegmentAE integrates autoencoders with clustering for anomaly detection.

    This class orchestrates the reconstruction, evaluation, and detection pipeline
    by combining autoencoder reconstruction errors with cluster-specific thresholding
    to optimize anomaly detection performance.

    The workflow consists of three phases:
    1. Reconstruction: Compute reconstruction errors on training data
    2. Evaluation: Test anomaly detection on labeled test data
    3. Detection: Predict anomalies on unlabeled data

    The same ``reconstruction()`` code path serves all three phases; the private
    ``_phase`` attribute decides which pair of result dictionaries is written
    (``preds_train``/``reconstruction_eval``, ``preds_test``/``reconstruction_test``
    or ``preds_final``/``reconstruction_final``).
    """

    def __init__(self, ae_model: Any, cl_model: Clustering):
        """
        Initialize SegmentAE pipeline.

        Args:
            ae_model: Autoencoder object exposing a ``predict`` method.
            cl_model: Fitted ``Clustering`` instance used to segment the data.

        Raises:
            ConfigurationError: If ``ae_model`` has no ``predict`` attribute or
                ``cl_model`` is not a ``Clustering`` instance.
            ModelNotFittedError: If ``cl_model`` has not been fitted.
        """
        # Validate inputs
        self._validate_autoencoder(ae_model)
        self._validate_clustering(cl_model)

        # Store models
        self.ae_model = ae_model
        self.cl_model = cl_model

        # State management
        self._phase: PhaseType = PhaseType.EVALUATION
        self._threshold: Optional[float] = None
        self._threshold_metric: Optional[ThresholdMetric] = None
        self._metric_column: Optional[str] = None

        # Results storage (for backward compatibility)
        self.preds_train: Dict[int, Dict] = {}
        self.preds_test: Dict[int, Dict] = {}
        self.reconstruction_eval: Dict[int, Dict] = {}
        self.reconstruction_test: Dict[int, Dict] = {}
        # Prediction-phase storage is initialised here as well so the attributes
        # exist before the first detections() call, like their train/test siblings.
        self.preds_final: Dict[int, Dict] = {}
        self.reconstruction_final: Dict[int, Dict] = {}
        self.results: Dict = {}

        # Internal state
        self._is_reconstruction_fitted: bool = False

    def reconstruction(self,
                       input_data: pd.DataFrame,
                       target_col: Optional[pd.Series] = None,
                       threshold_metric: str = "mse") -> Union['SegmentAE', tuple]:
        """
        Reconstruct input data and compute reconstruction errors per cluster.

        This method segments data by cluster, generates autoencoder reconstructions,
        and computes reconstruction errors using the specified metric.

        Args:
            input_data: Non-empty DataFrame to reconstruct.
            target_col: Optional labels with the same length as ``input_data``.
            threshold_metric: Metric name handed to ``ReconstructionConfig``.

        Returns:
            ``self`` when called directly (training phase); a
            ``(predictions, reconstruction_metrics)`` tuple when invoked
            internally by ``evaluation()`` or ``detections()``.

        Raises:
            ValidationError: If the inputs fail validation.
            AutoencoderError: If the autoencoder fails to produce reconstructions.
        """
        # Validate and configure
        self._validate_reconstruction_input(input_data, target_col)
        config = ReconstructionConfig(threshold_metric=threshold_metric)
        self._threshold_metric = config.threshold_metric
        self._metric_column = get_metric_column_name(self._threshold_metric)

        # Get cluster assignments
        cluster_predictions = self._get_cluster_predictions(input_data)

        # Process each cluster
        cluster_results = self._process_all_clusters(
            input_data=input_data,
            target_col=target_col,
            cluster_predictions=cluster_predictions
        )

        # Store results based on phase
        self._store_reconstruction_results(cluster_results)

        self._is_reconstruction_fitted = True
        return self._return_based_on_phase()

    def evaluation(self,
                   input_data: pd.DataFrame,
                   target_col: pd.Series,
                   threshold_ratio: float = 1.0) -> Dict[str, Any]:
        """
        Evaluate anomaly detection performance on labeled test data.

        Computes cluster-specific thresholds and evaluates detection performance
        against ground truth labels, providing both global and cluster-level metrics.

        Args:
            input_data: Non-empty test DataFrame.
            target_col: Ground-truth labels with the same length as ``input_data``.
            threshold_ratio: Ratio passed to ``EvaluationConfig``; adjusts the
                anomaly detection sensitivity relative to the baseline threshold.

        Returns:
            Dictionary with the keys ``"global metrics"``, ``"clusters metrics"``,
            ``"confusion matrix"`` and ``"y_true vs y_pred"``.

        Raises:
            ModelNotFittedError: If ``reconstruction()`` has not been called yet.
            ValidationError: If the inputs fail validation or no cluster could
                be evaluated.
        """
        self._validate_fitted_for_evaluation()
        self._validate_evaluation_input(input_data, target_col)

        # Threshold ratio adjusts the anomaly detection sensitivity relative to
        # the baseline threshold
        config = EvaluationConfig(threshold_ratio=threshold_ratio)

        # Set phase and run reconstruction
        self._phase = PhaseType.TESTING
        try:
            self.preds_test, self.reconstruction_test = self.reconstruction(
                input_data=input_data,
                target_col=target_col,
                threshold_metric=self._threshold_metric.value
            )

            # Evaluate each cluster
            cluster_results = self._evaluate_all_clusters(config.threshold_ratio)

            # Aggregate global results
            global_results = self._aggregate_evaluation_results(
                cluster_results,
                config.threshold_ratio
            )
        finally:
            # Always reset the phase: if it stayed on TESTING after a failure, the
            # next user-facing reconstruction() call would store its training
            # results in the test slots.
            self._phase = PhaseType.EVALUATION

        self.results = global_results
        return self.results

    def detections(self,
                   input_data: pd.DataFrame,
                   threshold_ratio: float = 1.0) -> pd.DataFrame:
        """
        Perform anomaly detection on unlabeled data.

        Uses trained cluster-specific thresholds to detect anomalies in new data
        without requiring ground truth labels.

        Args:
            input_data: Non-empty DataFrame to score.
            threshold_ratio: Ratio forwarded to ``compute_threshold`` for each cluster.

        Returns:
            DataFrame of reconstructions with a ``'Predicted Anomalies'`` column,
            ordered by the original row index. Only clusters that were also
            present during ``reconstruction()`` contribute rows.

        Raises:
            ModelNotFittedError: If ``reconstruction()`` has not been called yet.
            ValidationError: If the input fails validation or no cluster
                produced predictions.
        """
        self._validate_fitted_for_detection()
        self._validate_input(input_data, "Input data")

        # Set phase and run reconstruction
        self._phase = PhaseType.PREDICTION
        try:
            self.reconstruction(
                input_data=input_data,
                target_col=None,
                threshold_metric=self._threshold_metric.value
            )

            # Detect anomalies per cluster
            predictions = self._detect_anomalies_all_clusters(threshold_ratio)
        finally:
            # Reset even on failure so later calls are not routed to the
            # prediction storage by a stale phase.
            self._phase = PhaseType.EVALUATION

        return predictions

    def _process_all_clusters(self,
                              input_data: pd.DataFrame,
                              target_col: Optional[pd.Series],
                              cluster_predictions: pd.DataFrame) -> Dict[int, Dict]:
        """
        Process reconstruction for all clusters.

        Iterates over the distinct cluster labels (in order of first appearance)
        and returns, per cluster id, the raw rows, labels, reconstructions,
        original indices and error metrics.
        """
        cluster_model_name = self.cl_model.cluster_model[0]
        results = {}

        for cluster_id in cluster_predictions[cluster_model_name].unique():
            cluster_data = self._extract_cluster_data(
                input_data,
                target_col,
                cluster_predictions,
                cluster_model_name,
                cluster_id
            )

            # Generate reconstructions and compute errors
            reconstructions = self._reconstruct_cluster(cluster_data['X'])
            errors = self._compute_cluster_errors(
                cluster_data['X'],
                reconstructions,
                cluster_id  # label the per-column metrics with the real cluster
            )

            # Store cluster results
            results[cluster_id] = {
                "cluster": cluster_id,
                "real": cluster_data['X'],
                "y_true": cluster_data['y'],
                "predictions": reconstructions,
                "indexs": cluster_data['indices'],
                "errors": errors
            }

        return results

    def _extract_cluster_data(self,
                              input_data: pd.DataFrame,
                              target_col: Optional[pd.Series],
                              cluster_predictions: pd.DataFrame,
                              cluster_model_name: str,
                              cluster_id: int) -> Dict[str, Any]:
        """
        Extract data for a specific cluster.

        Returns a dict with ``'X'`` (rows of ``input_data``), ``'y'`` (matching
        labels or ``None``) and ``'indices'`` (list of selected index labels).
        """
        # NOTE(review): label-based selection; presumably cluster_predictions
        # shares its index with input_data - confirm in Clustering.cluster_prediction.
        cluster_indices = cluster_predictions.index[
            cluster_predictions[cluster_model_name] == cluster_id
        ].tolist()

        X_cluster = input_data.loc[cluster_indices]
        y_cluster = target_col.loc[cluster_indices] if target_col is not None else None

        return {
            'X': X_cluster,
            'y': y_cluster,
            'indices': cluster_indices
        }

    def _reconstruct_cluster(self, X_cluster: pd.DataFrame) -> pd.DataFrame:
        """
        Generate autoencoder reconstructions for cluster data.

        The result is rebuilt as a float DataFrame with the input's column names;
        the input's index is not carried over.

        Raises:
            AutoencoderError: If prediction or DataFrame construction fails; the
                original exception is attached as ``__cause__``.
        """
        try:
            predictions = self.ae_model.predict(X_cluster)
            return pd.DataFrame(predictions, columns=X_cluster.columns).astype(float)
        except Exception as e:
            raise AutoencoderError(
                f"Failed to generate reconstructions: {str(e)}"
            ) from e

    def _compute_cluster_errors(self,
                                real_values: pd.DataFrame,
                                predictions: pd.DataFrame,
                                cluster_id: int = 0) -> Dict[str, pd.DataFrame]:
        """
        Compute reconstruction errors for cluster.

        Args:
            real_values: Original rows of the cluster.
            predictions: Reconstructed rows, same shape as ``real_values``.
            cluster_id: Identifier forwarded to ``compute_column_metrics``.
                Defaults to 0, the value that was previously hard-coded.

        Returns:
            Dict with ``'metrics'`` (per-row error frame) and
            ``'column_metrics'`` (per-column metrics).
        """
        real_np = real_values.values
        pred_np = predictions.values

        # Compute per-row errors
        mse, mae, rmse, max_err = compute_reconstruction_errors(real_np, pred_np)
        metrics_df = create_metrics_dataframe(mse, mae, rmse, max_err)

        # Compute per-column metrics
        col_metrics = compute_column_metrics(
            real_np,
            pred_np,
            list(real_values.columns),
            cluster_id
        )

        return {
            'metrics': metrics_df,
            'column_metrics': col_metrics
        }

    def _evaluate_all_clusters(self, threshold_ratio: float) -> list:
        """
        Evaluate anomaly detection for all clusters.

        Only clusters seen during training are evaluated; a training cluster
        absent from the test data is reported on stdout and skipped.
        """
        cluster_results = []

        for cluster_id in self.reconstruction_eval:
            if cluster_id not in self.reconstruction_test:
                print(f"Warning: Cluster {cluster_id} not found in test data")
                continue

            result = self._evaluate_single_cluster(cluster_id, threshold_ratio)
            cluster_results.append(result)

        return cluster_results

    def _evaluate_single_cluster(self,
                                 cluster_id: int,
                                 threshold_ratio: float) -> Dict[str, Any]:
        """
        Evaluate a single cluster.

        Returns a dict with the cluster id, classification metrics, confusion
        matrix, predictions (including ``'Predicted Anomalies'``), original
        indices and positionally re-indexed labels.
        """
        # Compute threshold
        threshold = self._compute_cluster_threshold(cluster_id, threshold_ratio)

        # Get test data
        metrics_test = self.reconstruction_test[cluster_id]["metrics"]
        predictions = self.preds_test[cluster_id]["predictions"].copy()
        # Labels are re-indexed positionally because the reconstruction frame
        # does not keep the original index.
        y_test = self.preds_test[cluster_id]["y_true"].reset_index(drop=True)
        indices = self.preds_test[cluster_id]["indexs"]

        # Classify anomalies
        predictions['Predicted Anomalies'] = detect_anomalies(
            metrics_test[self._metric_column],
            threshold
        )

        # Compute metrics
        cm = confusion_matrix(y_test, predictions['Predicted Anomalies'])
        metrics = metrics_classification(y_test, predictions['Predicted Anomalies'])
        metrics["N_Cluster"] = cluster_id
        metrics["Threshold Metric"] = self._threshold_metric.value.upper()
        metrics["Threshold Value"] = round(threshold, 6)

        return {
            "cluster_id": cluster_id,
            "metrics": metrics,
            "confusion_matrix": cm,
            "predictions": predictions,
            "indices": indices,
            "y_test": y_test
        }

    def _compute_cluster_threshold(self,
                                   cluster_id: int,
                                   threshold_ratio: float) -> float:
        """
        Compute reconstruction threshold for a cluster.

        The threshold is obtained from ``compute_threshold`` applied to the
        training reconstruction errors of the cluster, and is echoed to stdout.
        """
        rec_errors = self.reconstruction_eval[cluster_id]["metrics"][self._metric_column]
        threshold = compute_threshold(rec_errors, threshold_ratio)
        print(f"Cluster {cluster_id} || Reconstruction Threshold: {round(threshold, 5)}")

        # Print empty line after last cluster. Compare against the last stored
        # key instead of assuming ids are the contiguous, ordered range 0..n-1:
        # keys are kept in order of first appearance in the training data.
        if cluster_id == list(self.reconstruction_eval)[-1]:
            print("")

        return threshold

    def _detect_anomalies_all_clusters(self,
                                       threshold_ratio: float) -> pd.DataFrame:
        """
        Detect anomalies across all clusters.

        Raises:
            ValidationError: If no training cluster is present in the new data.
        """
        all_predictions = []

        for cluster_id in self.reconstruction_eval:
            # A threshold only exists for clusters seen during training.
            if cluster_id not in self.reconstruction_final:
                continue

            cluster_preds = self._detect_cluster_anomalies(
                cluster_id,
                threshold_ratio
            )
            all_predictions.append(cluster_preds)

        # Aggregate and sort by original index
        if not all_predictions:
            raise ValidationError("No cluster predictions available")

        final_predictions = pd.concat(all_predictions, ignore_index=True)
        final_predictions = final_predictions.sort_values(by='_index')
        final_predictions = final_predictions.reset_index(drop=True)

        return final_predictions.drop('_index', axis=1)

    def _detect_cluster_anomalies(self,
                                  cluster_id: int,
                                  threshold_ratio: float) -> pd.DataFrame:
        """
        Detect anomalies for a single cluster.

        Returns the cluster's reconstructions with ``'Predicted Anomalies'`` and
        a temporary ``'_index'`` column holding the original row labels.
        """
        # Compute threshold (without printing)
        rec_errors = self.reconstruction_eval[cluster_id]["metrics"][self._metric_column]
        threshold = compute_threshold(rec_errors, threshold_ratio)

        # Get predictions
        recons_metrics = self.reconstruction_final[cluster_id]["metrics"]
        predictions = self.preds_final[cluster_id]["predictions"].copy()
        indices = self.preds_final[cluster_id]["indexs"]

        # Detect anomalies
        predictions['Predicted Anomalies'] = detect_anomalies(
            recons_metrics[self._metric_column],
            threshold
        )
        predictions['_index'] = indices

        return predictions

    def _aggregate_evaluation_results(self,
                                      cluster_results: list,
                                      threshold_ratio: float) -> Dict[str, Any]:
        """
        Aggregate cluster evaluation results into global metrics.

        Raises:
            ValidationError: If ``cluster_results`` is empty, i.e. no training
                cluster was found in the test data.
        """
        # Fail with a domain error rather than pandas' "No objects to
        # concatenate", mirroring _detect_anomalies_all_clusters.
        if not cluster_results:
            raise ValidationError(
                "No cluster evaluation results available",
                suggestion="Ensure the test data falls into clusters seen during reconstruction()"
            )

        # Cluster-level metrics
        cluster_metrics = pd.concat(
            [result["metrics"] for result in cluster_results],
            ignore_index=True
        )

        # Confusion matrices
        confusion_matrices = {
            result["cluster_id"]: {
                f"CM_{result['cluster_id']}": result["confusion_matrix"]
            }
            for result in cluster_results
        }

        # Global predictions
        all_predictions = [
            pd.DataFrame({
                'index': result["indices"],
                'y_test': result["y_test"],
                'Predicted Anomalies': result["predictions"]['Predicted Anomalies']
            })
            for result in cluster_results
        ]

        ytpred = pd.concat(all_predictions, ignore_index=True)
        ytpred = ytpred.sort_values(by='index').set_index('index')

        # Global metrics
        global_metrics = metrics_classification(
            ytpred['y_test'],
            ytpred['Predicted Anomalies']
        )
        global_metrics["Threshold Metric"] = self._threshold_metric.value.upper()
        global_metrics["Threshold Ratio"] = threshold_ratio

        return {
            "global metrics": global_metrics,
            "clusters metrics": cluster_metrics,
            "confusion matrix": confusion_matrices,
            "y_true vs y_pred": ytpred
        }

    # Storage and retrieval methods

    def _store_reconstruction_results(self, cluster_results: Dict) -> None:
        """
        Store reconstruction results based on current phase.

        Splits each cluster result into an error-metrics record and a
        predictions record, then assigns both dictionaries to the attribute
        pair that belongs to ``self._phase``.
        """
        # Convert cluster results to metrics format
        metrics_results = {}
        preds_results = {}

        for cluster_id, result in cluster_results.items():
            metrics_results[cluster_id] = {
                "cluster": result["cluster"],
                "metrics": result["errors"]["metrics"],
                "column_metrics": result["errors"]["column_metrics"],
                "indexs": result["indexs"]
            }

            preds_results[cluster_id] = {
                "cluster": result["cluster"],
                "real": result["real"],
                "y_true": result["y_true"],
                "predictions": result["predictions"],
                "indexs": result["indexs"]
            }

        if self._phase == PhaseType.EVALUATION:
            self.preds_train = preds_results
            self.reconstruction_eval = metrics_results
        elif self._phase == PhaseType.TESTING:
            self.preds_test = preds_results
            self.reconstruction_test = metrics_results
        elif self._phase == PhaseType.PREDICTION:
            self.preds_final = preds_results
            self.reconstruction_final = metrics_results

    def _return_based_on_phase(self) -> Union['SegmentAE', tuple]:
        """
        Return appropriate results based on phase.

        ``self`` for the training phase, otherwise the
        ``(predictions, reconstruction_metrics)`` pair of the active phase.
        """
        if self._phase == PhaseType.EVALUATION:
            return self
        elif self._phase == PhaseType.TESTING:
            return self.preds_test, self.reconstruction_test
        elif self._phase == PhaseType.PREDICTION:
            return self.preds_final, self.reconstruction_final

    # Helper methods

    def _get_cluster_predictions(self, input_data: pd.DataFrame) -> pd.DataFrame:
        """Get cluster assignments for input data."""
        return self.cl_model.cluster_prediction(X=input_data)

    def _validate_autoencoder(self, ae_model: Any) -> None:
        """
        Validate that autoencoder has required interface.

        Raises:
            ConfigurationError: If ``ae_model`` has no ``predict`` attribute.
        """
        if not hasattr(ae_model, 'predict'):
            raise ConfigurationError(
                "Autoencoder must have a 'predict' method. "
                "Ensure you're passing a trained Keras model or built-in autoencoder."
            )

    def _validate_clustering(self, cl_model: Clustering) -> None:
        """
        Validate clustering model.

        Raises:
            ConfigurationError: If ``cl_model`` is not a ``Clustering`` instance.
            ModelNotFittedError: If ``cl_model._is_fitted`` is falsy.
        """
        if not isinstance(cl_model, Clustering):
            raise ConfigurationError(
                "cl_model must be an instance of Clustering",
                valid_options=["Clustering"]
            )
        if not cl_model._is_fitted:
            raise ModelNotFittedError(
                component="Clustering",
                message="Clustering model must be fitted before use. "
                        "Call clustering_fit(X) method first."
            )

    def _validate_reconstruction_input(self,
                                       input_data: pd.DataFrame,
                                       target_col: Optional[pd.Series]) -> None:
        """
        Validate reconstruction inputs.

        ``target_col`` is optional here; when given it must be a Series whose
        length matches ``input_data``.

        Raises:
            ValidationError: If any check fails.
        """
        self._validate_input(input_data, "Input data")

        if target_col is not None:
            if not isinstance(target_col, pd.Series):
                raise ValidationError(
                    f"target_col must be a pandas Series, got {type(target_col).__name__}",
                    suggestion="Convert to Series using pd.Series() or use DataFrame column"
                )
            if len(target_col) != len(input_data):
                raise ValidationError(
                    f"target_col length ({len(target_col)}) must match "
                    f"input_data length ({len(input_data)})"
                )

    def _validate_evaluation_input(self,
                                   input_data: pd.DataFrame,
                                   target_col: pd.Series) -> None:
        """
        Validate evaluation inputs.

        Unlike reconstruction, evaluation requires ``target_col``.

        Raises:
            ValidationError: If any check fails.
        """
        self._validate_input(input_data, "Input data")

        if not isinstance(target_col, pd.Series):
            raise ValidationError(
                f"target_col must be a pandas Series, got {type(target_col).__name__}",
                suggestion="Use test[target_column] to extract Series"
            )

        if len(target_col) != len(input_data):
            raise ValidationError(
                f"target_col length ({len(target_col)}) must match "
                f"input_data length ({len(input_data)})"
            )

    def _validate_input(self, X: pd.DataFrame, context: str = "Input") -> None:
        """
        Validate input DataFrame.

        Args:
            X: Object expected to be a non-empty DataFrame.
            context: Label used as the prefix of the error message.

        Raises:
            ValidationError: If ``X`` is not a DataFrame or is empty.
        """
        if not isinstance(X, pd.DataFrame):
            raise ValidationError(
                f"{context} must be a pandas DataFrame, got {type(X).__name__}",
                suggestion="Convert to DataFrame using pd.DataFrame()"
            )

        if X.empty:
            raise ValidationError(
                f"{context} DataFrame is empty",
                suggestion="Ensure your dataset contains data"
            )

    def _validate_fitted_for_evaluation(self) -> None:
        """
        Validate that reconstruction has been performed.

        Raises:
            ModelNotFittedError: If ``reconstruction()`` has not completed yet.
        """
        if not self._is_reconstruction_fitted:
            raise ModelNotFittedError(
                component="SegmentAE",
                message="Must call reconstruction() before evaluation(). "
                        "Example: sg.reconstruction(X_train, threshold_metric='mse')"
            )

    def _validate_fitted_for_detection(self) -> None:
        """
        Validate that reconstruction has been performed.

        Raises:
            ModelNotFittedError: If ``reconstruction()`` has not completed yet.
        """
        if not self._is_reconstruction_fitted:
            raise ModelNotFittedError(
                component="SegmentAE",
                message="Must call reconstruction() before detections(). "
                        "Example: sg.reconstruction(X_train, threshold_metric='mse')"
            )

    def __repr__(self) -> str:
        """String representation of SegmentAE."""
        ae_name = type(self.ae_model).__name__
        return (
            f"SegmentAE("
            f"autoencoder={ae_name}, "
            f"clustering={self.cl_model.cluster_model}, "
            f"fitted={self._is_reconstruction_fitted})"
        )
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
from segmentae.pipeline.reconstruction import (
|
|
2
|
+
ClusterReconstruction,
|
|
3
|
+
ReconstructionMetrics,
|
|
4
|
+
compute_column_metrics,
|
|
5
|
+
compute_reconstruction_errors,
|
|
6
|
+
compute_threshold,
|
|
7
|
+
detect_anomalies,
|
|
8
|
+
)
|
|
9
|
+
from segmentae.pipeline.segmentae import EvaluationConfig, ReconstructionConfig, SegmentAE
|
|
10
|
+
|
|
11
|
+
__all__ = [
|
|
12
|
+
'SegmentAE',
|
|
13
|
+
'ReconstructionConfig',
|
|
14
|
+
'EvaluationConfig',
|
|
15
|
+
'ClusterReconstruction',
|
|
16
|
+
'ReconstructionMetrics',
|
|
17
|
+
'compute_reconstruction_errors',
|
|
18
|
+
'compute_column_metrics',
|
|
19
|
+
'compute_threshold',
|
|
20
|
+
'detect_anomalies'
|
|
21
|
+
]
|