segmentae 1.5.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,562 @@
1
+ import warnings
2
+ from typing import Any, Dict, Optional, Union
3
+
4
+ import pandas as pd
5
+ from sklearn.metrics import confusion_matrix
6
+
7
+ from segmentae.clusters.clustering import Clustering
8
+ from segmentae.core.constants import (
9
+ PhaseType,
10
+ ThresholdMetric,
11
+ get_metric_column_name,
12
+ )
13
+ from segmentae.core.exceptions import (
14
+ AutoencoderError,
15
+ ConfigurationError,
16
+ ModelNotFittedError,
17
+ ValidationError,
18
+ )
19
+ from segmentae.metrics.performance_metrics import metrics_classification
20
+ from segmentae.pipeline.reconstruction import (
21
+ EvaluationConfig,
22
+ ReconstructionConfig,
23
+ compute_column_metrics,
24
+ compute_reconstruction_errors,
25
+ compute_threshold,
26
+ create_metrics_dataframe,
27
+ detect_anomalies,
28
+ )
29
+
30
# NOTE(review): this runs at import time and inserts an "ignore" filter for the
# base Warning class, so it silences warnings process-wide (including those
# raised by user code and other libraries), not just within this module.
warnings.simplefilter("ignore", category=Warning)
31
+
32
class SegmentAE:
    """
    SegmentAE integrates autoencoders with clustering for anomaly detection.

    This class orchestrates the reconstruction, evaluation, and detection pipeline
    by combining autoencoder reconstruction errors with cluster-specific thresholding
    to optimize anomaly detection performance.

    The workflow consists of three phases:
        1. Reconstruction: Compute reconstruction errors on training data
        2. Evaluation: Test anomaly detection on labeled test data
        3. Detection: Predict anomalies on unlabeled data

    Per-cluster thresholds are always derived from the errors stored by the most
    recent ``reconstruction()`` call made in the default phase (i.e. a direct
    user call, not the internal calls made by ``evaluation()``/``detections()``).
    """

    def __init__(self, ae_model: Any, cl_model: Clustering):
        """
        Initialize SegmentAE pipeline.

        Args:
            ae_model: Trained autoencoder; must expose a ``predict`` method.
            cl_model: Fitted ``Clustering`` instance used to assign rows to clusters.

        Raises:
            ConfigurationError: If ``ae_model`` has no ``predict`` method or
                ``cl_model`` is not a ``Clustering`` instance.
            ModelNotFittedError: If ``cl_model`` has not been fitted.
        """
        # Validate inputs
        self._validate_autoencoder(ae_model)
        self._validate_clustering(cl_model)

        # Store models
        self.ae_model = ae_model
        self.cl_model = cl_model

        # State management: the phase decides where reconstruction() stores its
        # results and what it returns (see _store_reconstruction_results).
        self._phase: PhaseType = PhaseType.EVALUATION
        self._threshold: Optional[float] = None
        self._threshold_metric: Optional[ThresholdMetric] = None
        self._metric_column: Optional[str] = None

        # Results storage (for backward compatibility). The *_final dicts are
        # initialised here too so they exist before the first detections() call.
        self.preds_train: Dict[int, Dict] = {}
        self.preds_test: Dict[int, Dict] = {}
        self.preds_final: Dict[int, Dict] = {}
        self.reconstruction_eval: Dict[int, Dict] = {}
        self.reconstruction_test: Dict[int, Dict] = {}
        self.reconstruction_final: Dict[int, Dict] = {}
        self.results: Dict = {}

        # Internal state
        self._is_reconstruction_fitted: bool = False

    def reconstruction(self,
                       input_data: pd.DataFrame,
                       target_col: Optional[pd.Series] = None,
                       threshold_metric: str = "mse") -> Union['SegmentAE', tuple]:
        """
        Reconstruct input data and compute reconstruction errors per cluster.

        This method segments data by cluster, generates autoencoder reconstructions,
        and computes reconstruction errors using the specified metric.

        Args:
            input_data: Feature matrix to reconstruct.
            target_col: Optional labels aligned with ``input_data``.
            threshold_metric: Name of the error metric used for thresholding
                (validated by ``ReconstructionConfig``).

        Returns:
            ``self`` when called directly (default phase); a
            ``(predictions, reconstruction_metrics)`` tuple when invoked
            internally during the testing or prediction phase.

        Raises:
            ValidationError: If ``input_data``/``target_col`` are malformed.
            AutoencoderError: If the autoencoder fails to generate reconstructions.
        """
        # Validate and configure
        self._validate_reconstruction_input(input_data, target_col)
        config = ReconstructionConfig(threshold_metric=threshold_metric)
        self._threshold_metric = config.threshold_metric
        self._metric_column = get_metric_column_name(self._threshold_metric)

        # Get cluster assignments
        cluster_predictions = self._get_cluster_predictions(input_data)

        # Process each cluster
        cluster_results = self._process_all_clusters(
            input_data=input_data,
            target_col=target_col,
            cluster_predictions=cluster_predictions
        )

        # Store results based on phase
        self._store_reconstruction_results(cluster_results)

        self._is_reconstruction_fitted = True
        return self._return_based_on_phase()

    def evaluation(self,
                   input_data: pd.DataFrame,
                   target_col: pd.Series,
                   threshold_ratio: float = 1.0) -> Dict[str, Any]:
        """
        Evaluate anomaly detection performance on labeled test data.

        Computes cluster-specific thresholds and evaluates detection performance
        against ground truth labels, providing both global and cluster-level metrics.

        Args:
            input_data: Test feature matrix.
            target_col: Ground-truth labels aligned with ``input_data``.
            threshold_ratio: Multiplier adjusting detection sensitivity relative
                to each cluster's baseline threshold.

        Returns:
            Dict with keys ``"global metrics"``, ``"clusters metrics"``,
            ``"confusion matrix"`` and ``"y_true vs y_pred"``.

        Raises:
            ModelNotFittedError: If ``reconstruction()`` has not been called.
            ValidationError: If inputs are malformed or no cluster is shared
                between the reconstruction data and the test data.
        """
        self._validate_fitted_for_evaluation()
        self._validate_evaluation_input(input_data, target_col)

        # Threshold ratio adjusts the anomaly detection sensitivity relative to the baseline threshold
        config = EvaluationConfig(threshold_ratio=threshold_ratio)

        # Set phase and run reconstruction. The phase is restored in `finally`
        # so a failure here cannot leave the object in TESTING, where later
        # direct reconstruction() calls would overwrite the test results.
        self._phase = PhaseType.TESTING
        try:
            self.preds_test, self.reconstruction_test = self.reconstruction(
                input_data=input_data,
                target_col=target_col,
                threshold_metric=self._threshold_metric.value
            )
        finally:
            self._phase = PhaseType.EVALUATION  # Reset phase

        # Evaluate each cluster
        cluster_results = self._evaluate_all_clusters(config.threshold_ratio)
        if not cluster_results:
            raise ValidationError(
                "No clusters in common between reconstruction data and evaluation data",
                suggestion="Ensure the evaluation data contains clusters seen during reconstruction()"
            )

        # Aggregate global results
        self.results = self._aggregate_evaluation_results(
            cluster_results,
            config.threshold_ratio
        )
        return self.results

    def detections(self,
                   input_data: pd.DataFrame,
                   threshold_ratio: float = 1.0) -> pd.DataFrame:
        """
        Perform anomaly detection on unlabeled data.

        Uses trained cluster-specific thresholds to detect anomalies in new data
        without requiring ground truth labels.

        Args:
            input_data: Feature matrix to score.
            threshold_ratio: Multiplier applied to each cluster's baseline threshold.

        Returns:
            Reconstructed features plus a ``'Predicted Anomalies'`` column, one
            row per input row, in the row order of ``input_data`` (when its index
            is unique) and with a fresh ``RangeIndex``.

        Raises:
            ModelNotFittedError: If ``reconstruction()`` has not been called.
            ValidationError: If ``input_data`` is malformed or no cluster
                produced predictions.
        """
        self._validate_fitted_for_detection()
        self._validate_input(input_data, "Input data")

        # Set phase and run reconstruction; always restore the default phase.
        self._phase = PhaseType.PREDICTION
        try:
            self.reconstruction(
                input_data=input_data,
                target_col=None,
                threshold_metric=self._threshold_metric.value
            )
        finally:
            self._phase = PhaseType.EVALUATION  # Reset phase

        # Detect anomalies per cluster
        return self._detect_anomalies_all_clusters(
            threshold_ratio,
            input_index=input_data.index
        )

    def _process_all_clusters(self,
                              input_data: pd.DataFrame,
                              target_col: Optional[pd.Series],
                              cluster_predictions: pd.DataFrame) -> Dict[int, Dict]:
        """
        Reconstruct every cluster found in ``cluster_predictions``.

        Returns a dict keyed by cluster id holding the cluster's real rows,
        labels, reconstructions, original index labels and reconstruction errors.
        """
        cluster_model_name = self.cl_model.cluster_model[0]
        results = {}

        for cluster_id in cluster_predictions[cluster_model_name].unique():
            cluster_data = self._extract_cluster_data(
                input_data,
                target_col,
                cluster_predictions,
                cluster_model_name,
                cluster_id
            )

            # Generate reconstructions and compute errors
            reconstructions = self._reconstruct_cluster(cluster_data['X'])
            errors = self._compute_cluster_errors(
                cluster_data['X'],
                reconstructions,
                cluster_id=cluster_id
            )

            # Store cluster results
            results[cluster_id] = {
                "cluster": cluster_id,
                "real": cluster_data['X'],
                "y_true": cluster_data['y'],
                "predictions": reconstructions,
                "indexs": cluster_data['indices'],
                "errors": errors
            }

        return results

    def _extract_cluster_data(self,
                              input_data: pd.DataFrame,
                              target_col: Optional[pd.Series],
                              cluster_predictions: pd.DataFrame,
                              cluster_model_name: str,
                              cluster_id: int) -> Dict[str, Any]:
        """
        Extract the rows (and labels, if any) assigned to ``cluster_id``.

        Rows are selected by label: the index labels of ``cluster_predictions``
        are looked up in ``input_data`` and ``target_col`` with ``.loc``.
        """
        cluster_indices = cluster_predictions.index[
            cluster_predictions[cluster_model_name] == cluster_id
        ].tolist()

        X_cluster = input_data.loc[cluster_indices]
        y_cluster = target_col.loc[cluster_indices] if target_col is not None else None

        return {
            'X': X_cluster,
            'y': y_cluster,
            'indices': cluster_indices
        }

    def _reconstruct_cluster(self, X_cluster: pd.DataFrame) -> pd.DataFrame:
        """
        Generate autoencoder reconstructions for cluster data.

        The result keeps ``X_cluster``'s columns but gets a fresh positional
        index; downstream code aligns it positionally with the error metrics.

        Raises:
            AutoencoderError: If ``ae_model.predict`` (or the conversion) fails;
                the original exception is chained as the cause.
        """
        try:
            predictions = self.ae_model.predict(X_cluster)
            return pd.DataFrame(predictions, columns=X_cluster.columns).astype(float)
        except Exception as e:
            raise AutoencoderError(
                f"Failed to generate reconstructions: {str(e)}"
            ) from e

    def _compute_cluster_errors(self,
                                real_values: pd.DataFrame,
                                predictions: pd.DataFrame,
                                cluster_id: int = 0) -> Dict[str, pd.DataFrame]:
        """
        Compute per-row and per-column reconstruction errors for one cluster.

        Args:
            real_values: Original rows of the cluster.
            predictions: Autoencoder reconstructions of ``real_values``.
            cluster_id: Cluster these rows belong to; forwarded to
                ``compute_column_metrics`` so its output is tagged correctly.

        Returns:
            Dict with ``'metrics'`` (per-row errors) and ``'column_metrics'``.
        """
        real_np = real_values.values
        pred_np = predictions.values

        # Compute per-row errors
        mse, mae, rmse, max_err = compute_reconstruction_errors(real_np, pred_np)
        metrics_df = create_metrics_dataframe(mse, mae, rmse, max_err)

        # Compute per-column metrics
        col_metrics = compute_column_metrics(
            real_np,
            pred_np,
            list(real_values.columns),
            cluster_id
        )

        return {
            'metrics': metrics_df,
            'column_metrics': col_metrics
        }

    def _evaluate_all_clusters(self, threshold_ratio: float) -> list:
        """
        Evaluate anomaly detection for every cluster present in both the
        reconstruction (training) results and the test results.

        Clusters present in only one of them are reported with a printed
        warning and skipped.
        """
        cluster_results = []

        for cluster_id in self.reconstruction_eval.keys():
            if cluster_id not in self.reconstruction_test:
                print(f"Warning: Cluster {cluster_id} not found in test data")
                continue

            result = self._evaluate_single_cluster(cluster_id, threshold_ratio)
            cluster_results.append(result)

        # Test rows in clusters unseen during reconstruction have no threshold;
        # report them rather than excluding them silently from the metrics.
        for cluster_id in self.reconstruction_test.keys():
            if cluster_id not in self.reconstruction_eval:
                print(f"Warning: Cluster {cluster_id} not found in reconstruction data; "
                      f"its rows are excluded from evaluation")

        # Print empty line after the per-cluster threshold report
        if cluster_results:
            print("")

        return cluster_results

    def _evaluate_single_cluster(self,
                                 cluster_id: int,
                                 threshold_ratio: float) -> Dict[str, Any]:
        """Classify one cluster's test rows and compute its metrics."""
        # Compute threshold
        threshold = self._compute_cluster_threshold(cluster_id, threshold_ratio)

        # Get test data
        metrics_test = self.reconstruction_test[cluster_id]["metrics"]
        predictions = self.preds_test[cluster_id]["predictions"].copy()
        y_test = self.preds_test[cluster_id]["y_true"].reset_index(drop=True)
        indices = self.preds_test[cluster_id]["indexs"]

        # Classify anomalies
        predictions['Predicted Anomalies'] = detect_anomalies(
            metrics_test[self._metric_column],
            threshold
        )

        # Compute metrics
        cm = confusion_matrix(y_test, predictions['Predicted Anomalies'])
        metrics = metrics_classification(y_test, predictions['Predicted Anomalies'])
        metrics["N_Cluster"] = cluster_id
        metrics["Threshold Metric"] = self._threshold_metric.value.upper()
        metrics["Threshold Value"] = round(threshold, 6)

        return {
            "cluster_id": cluster_id,
            "metrics": metrics,
            "confusion_matrix": cm,
            "predictions": predictions,
            "indices": indices,
            "y_test": y_test
        }

    def _compute_cluster_threshold(self,
                                   cluster_id: int,
                                   threshold_ratio: float,
                                   verbose: bool = True) -> float:
        """
        Compute the reconstruction threshold for a cluster from the errors
        stored by the default-phase ``reconstruction()`` call.

        Args:
            cluster_id: Cluster whose stored errors are used.
            threshold_ratio: Multiplier passed to ``compute_threshold``.
            verbose: Print the threshold when True.
        """
        rec_errors = self.reconstruction_eval[cluster_id]["metrics"][self._metric_column]
        threshold = compute_threshold(rec_errors, threshold_ratio)
        if verbose:
            print(f"Cluster {cluster_id} || Reconstruction Threshold: {round(threshold, 5)}")
        return threshold

    def _detect_anomalies_all_clusters(self,
                                       threshold_ratio: float,
                                       input_index: Optional[pd.Index] = None) -> pd.DataFrame:
        """
        Detect anomalies across all clusters and reassemble one prediction frame.

        Args:
            threshold_ratio: Multiplier applied to each cluster's baseline threshold.
            input_index: Index of the data passed to ``detections()``. When given
                and unique, rows are returned in that data's row order; otherwise
                they are sorted by index label.

        Returns:
            Reconstructed features plus ``'Predicted Anomalies'``, with a fresh
            ``RangeIndex``.

        Raises:
            ValidationError: If no cluster produced predictions.
        """
        all_predictions = [
            self._detect_cluster_anomalies(cluster_id, threshold_ratio)
            for cluster_id in self.reconstruction_eval
            if cluster_id in self.reconstruction_final
        ]

        # Rows in clusters unseen during reconstruction have no threshold;
        # report them rather than dropping them silently from the output.
        for cluster_id in self.reconstruction_final:
            if cluster_id not in self.reconstruction_eval:
                print(f"Warning: Cluster {cluster_id} not found in reconstruction data; "
                      f"its rows are excluded from the detections")

        if not all_predictions:
            raise ValidationError("No cluster predictions available")

        # Aggregate and restore the original row order
        final_predictions = pd.concat(all_predictions, ignore_index=True)
        final_predictions = self._restore_row_order(final_predictions, '_index', input_index)
        final_predictions = final_predictions.reset_index(drop=True)

        return final_predictions.drop('_index', axis=1)

    @staticmethod
    def _restore_row_order(frame: pd.DataFrame,
                           label_column: str,
                           input_index: Optional[pd.Index]) -> pd.DataFrame:
        """
        Reorder ``frame`` so ``label_column`` follows the row order of ``input_index``.

        Falls back to sorting by label when the index is unavailable, not unique,
        or does not contain every label. For a monotonically increasing index
        both strategies give the same order.
        """
        if input_index is not None and input_index.is_unique:
            positions = input_index.get_indexer(frame[label_column])
            if (positions >= 0).all():
                return frame.iloc[positions.argsort(kind='stable')]
        return frame.sort_values(by=label_column)

    def _detect_cluster_anomalies(self,
                                  cluster_id: int,
                                  threshold_ratio: float) -> pd.DataFrame:
        """
        Flag anomalies in one cluster of the prediction data.

        The returned frame carries a ``'_index'`` column with the original
        index labels so the caller can restore row order.
        """
        # Compute threshold (without printing)
        threshold = self._compute_cluster_threshold(cluster_id, threshold_ratio, verbose=False)

        # Get predictions
        recons_metrics = self.reconstruction_final[cluster_id]["metrics"]
        predictions = self.preds_final[cluster_id]["predictions"].copy()
        indices = self.preds_final[cluster_id]["indexs"]

        # Detect anomalies
        predictions['Predicted Anomalies'] = detect_anomalies(
            recons_metrics[self._metric_column],
            threshold
        )
        predictions['_index'] = indices

        return predictions

    def _aggregate_evaluation_results(self,
                                      cluster_results: list,
                                      threshold_ratio: float) -> Dict[str, Any]:
        """Aggregate cluster evaluation results into global metrics."""
        # Cluster-level metrics
        cluster_metrics = pd.concat(
            [result["metrics"] for result in cluster_results],
            ignore_index=True
        )

        # Confusion matrices
        confusion_matrices = {
            result["cluster_id"]: {
                f"CM_{result['cluster_id']}": result["confusion_matrix"]
            }
            for result in cluster_results
        }

        # Global predictions, keyed by original index label
        all_predictions = [
            pd.DataFrame({
                'index': result["indices"],
                'y_test': result["y_test"],
                'Predicted Anomalies': result["predictions"]['Predicted Anomalies']
            })
            for result in cluster_results
        ]

        ytpred = pd.concat(all_predictions, ignore_index=True)
        ytpred = ytpred.sort_values(by='index').set_index('index')

        # Global metrics
        global_metrics = metrics_classification(
            ytpred['y_test'],
            ytpred['Predicted Anomalies']
        )
        global_metrics["Threshold Metric"] = self._threshold_metric.value.upper()
        global_metrics["Threshold Ratio"] = threshold_ratio

        return {
            "global metrics": global_metrics,
            "clusters metrics": cluster_metrics,
            "confusion matrix": confusion_matrices,
            "y_true vs y_pred": ytpred
        }

    # Storage and retrieval methods

    def _store_reconstruction_results(self, cluster_results: Dict) -> None:
        """Split cluster results into metrics/predictions dicts and store them in
        the slots that belong to the current phase."""
        metrics_results = {}
        preds_results = {}

        for cluster_id, result in cluster_results.items():
            metrics_results[cluster_id] = {
                "cluster": result["cluster"],
                "metrics": result["errors"]["metrics"],
                "column_metrics": result["errors"]["column_metrics"],
                "indexs": result["indexs"]
            }

            preds_results[cluster_id] = {
                "cluster": result["cluster"],
                "real": result["real"],
                "y_true": result["y_true"],
                "predictions": result["predictions"],
                "indexs": result["indexs"]
            }

        if self._phase == PhaseType.EVALUATION:
            self.preds_train = preds_results
            self.reconstruction_eval = metrics_results
        elif self._phase == PhaseType.TESTING:
            self.preds_test = preds_results
            self.reconstruction_test = metrics_results
        elif self._phase == PhaseType.PREDICTION:
            self.preds_final = preds_results
            self.reconstruction_final = metrics_results

    def _return_based_on_phase(self) -> Union['SegmentAE', tuple]:
        """Return ``self`` in the default phase, else the phase's result tuple."""
        if self._phase == PhaseType.EVALUATION:
            return self
        elif self._phase == PhaseType.TESTING:
            return self.preds_test, self.reconstruction_test
        elif self._phase == PhaseType.PREDICTION:
            return self.preds_final, self.reconstruction_final

    # Helper methods

    def _get_cluster_predictions(self, input_data: pd.DataFrame) -> pd.DataFrame:
        """Get cluster assignments for input data."""
        return self.cl_model.cluster_prediction(X=input_data)

    def _validate_autoencoder(self, ae_model: Any) -> None:
        """Validate that autoencoder has required interface."""
        if not hasattr(ae_model, 'predict'):
            raise ConfigurationError(
                "Autoencoder must have a 'predict' method. "
                "Ensure you're passing a trained Keras model or built-in autoencoder."
            )

    def _validate_clustering(self, cl_model: Clustering) -> None:
        """Validate that ``cl_model`` is a fitted ``Clustering`` instance."""
        if not isinstance(cl_model, Clustering):
            raise ConfigurationError(
                "cl_model must be an instance of Clustering",
                valid_options=["Clustering"]
            )
        if not cl_model._is_fitted:
            raise ModelNotFittedError(
                component="Clustering",
                message="Clustering model must be fitted before use. "
                        "Call clustering_fit(X) method first."
            )

    def _validate_reconstruction_input(self,
                                       input_data: pd.DataFrame,
                                       target_col: Optional[pd.Series]) -> None:
        """Validate reconstruction inputs (labels are optional)."""
        self._validate_input(input_data, "Input data")

        if target_col is not None:
            if not isinstance(target_col, pd.Series):
                raise ValidationError(
                    f"target_col must be a pandas Series, got {type(target_col).__name__}",
                    suggestion="Convert to Series using pd.Series() or use DataFrame column"
                )
            if len(target_col) != len(input_data):
                raise ValidationError(
                    f"target_col length ({len(target_col)}) must match "
                    f"input_data length ({len(input_data)})"
                )

    def _validate_evaluation_input(self,
                                   input_data: pd.DataFrame,
                                   target_col: pd.Series) -> None:
        """Validate evaluation inputs (labels are required)."""
        self._validate_input(input_data, "Input data")

        if not isinstance(target_col, pd.Series):
            raise ValidationError(
                f"target_col must be a pandas Series, got {type(target_col).__name__}",
                suggestion="Use test[target_column] to extract Series"
            )

        if len(target_col) != len(input_data):
            raise ValidationError(
                f"target_col length ({len(target_col)}) must match "
                f"input_data length ({len(input_data)})"
            )

    def _validate_input(self, X: pd.DataFrame, context: str = "Input") -> None:
        """Validate that ``X`` is a non-empty DataFrame."""
        if not isinstance(X, pd.DataFrame):
            raise ValidationError(
                f"{context} must be a pandas DataFrame, got {type(X).__name__}",
                suggestion="Convert to DataFrame using pd.DataFrame()"
            )

        if X.empty:
            raise ValidationError(
                f"{context} DataFrame is empty",
                suggestion="Ensure your dataset contains data"
            )

    def _validate_fitted_for_evaluation(self) -> None:
        """Validate that reconstruction has been performed."""
        if not self._is_reconstruction_fitted:
            raise ModelNotFittedError(
                component="SegmentAE",
                message="Must call reconstruction() before evaluation(). "
                        "Example: sg.reconstruction(X_train, threshold_metric='mse')"
            )

    def _validate_fitted_for_detection(self) -> None:
        """Validate that reconstruction has been performed."""
        if not self._is_reconstruction_fitted:
            raise ModelNotFittedError(
                component="SegmentAE",
                message="Must call reconstruction() before detections(). "
                        "Example: sg.reconstruction(X_train, threshold_metric='mse')"
            )

    def __repr__(self) -> str:
        """String representation of SegmentAE."""
        ae_name = type(self.ae_model).__name__
        return (
            f"SegmentAE("
            f"autoencoder={ae_name}, "
            f"clustering={self.cl_model.cluster_model}, "
            f"fitted={self._is_reconstruction_fitted})"
        )
@@ -0,0 +1,21 @@
1
+ from segmentae.pipeline.reconstruction import (
2
+ ClusterReconstruction,
3
+ ReconstructionMetrics,
4
+ compute_column_metrics,
5
+ compute_reconstruction_errors,
6
+ compute_threshold,
7
+ detect_anomalies,
8
+ )
9
+ from segmentae.pipeline.segmentae import EvaluationConfig, ReconstructionConfig, SegmentAE
10
+
11
# Public API of the pipeline package: the SegmentAE orchestrator, its
# configuration objects, and the reconstruction/threshold helper functions.
__all__ = [
    "SegmentAE",
    "ReconstructionConfig",
    "EvaluationConfig",
    "ClusterReconstruction",
    "ReconstructionMetrics",
    "compute_reconstruction_errors",
    "compute_column_metrics",
    "compute_threshold",
    "detect_anomalies",
]