graflag-evaluator 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,14 @@
+ Metadata-Version: 2.4
+ Name: graflag_evaluator
+ Version: 1.0.0
+ Summary: Evaluation framework for graph anomaly detection methods in GraFlag
+ Author: GraFlag Team
+ Requires-Python: >=3.7
+ Requires-Dist: numpy>=1.21.0
+ Requires-Dist: scikit-learn>=1.0.0
+ Requires-Dist: matplotlib>=3.5.0
+ Requires-Dist: pandas>=1.3.0
+ Dynamic: author
+ Dynamic: requires-dist
+ Dynamic: requires-python
+ Dynamic: summary
@@ -0,0 +1,132 @@
+ # GraFlag Evaluator
+
+ Docker-based evaluation system for graph anomaly detection experiments.
+
+ ## Features
+
+ - Automatic metric computation based on result type
+ - Plot generation: ROC curves, PR curves, score distributions, spot curves
+ - Spot file integration: detects and plots training/validation metrics
+ - Standardized output: evaluation.json with all metrics and metadata
+ - Docker-based: isolated environment with all dependencies
+
+ ## Usage
+
+ ### From CLI
+
+ ```bash
+ # Evaluate an experiment (builds the Docker image on first run)
+ graflag evaluate -e exp__generaldyg__btc_alpha__20251211_120000
+
+ # Copy results locally
+ graflag copy --from-remote -s experiments/<exp_name>/eval -d ./eval_results
+ ```
+
+ ### Manual Docker Usage
+
+ ```bash
+ # Build image (done automatically by the CLI)
+ cd graflag-shared/libs/graflag_evaluator
+ docker build -t graflag-evaluator:latest .
+
+ # Run evaluation
+ docker run --rm -v /shared:/shared graflag-evaluator:latest /shared/experiments/<exp_name>
+ ```
+
+ ### From Python
+
+ ```python
+ from pathlib import Path
+
+ from graflag_evaluator import Evaluator
+
+ evaluator = Evaluator(Path("experiments/exp__generaldyg__btc_alpha__20251211_120000"))
+ eval_path = evaluator.evaluate()
+ ```
+
+ ## Supported Metrics
+
+ All result types receive:
+
+ - **AUC-ROC**: Area under the ROC curve
+ - **AUC-PR**: Area under the Precision-Recall curve
+ - **Precision@K**: Precision over the top K predictions
+ - **Recall@K**: Recall over the top K predictions
+ - **F1@K**: F1 score over the top K predictions
+ - **Best F1**: Best F1 score across all thresholds
+
+ Additional metrics are computed based on result type (edge counts, temporal span, etc.).
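+
+ As a hedged illustration of how the shared metrics are derived (mirroring `compute_classification_metrics` in `metrics.py`; the data here is made up):
+
+ ```python
+ import numpy as np
+ from sklearn import metrics
+
+ scores = np.array([0.9, 0.2, 0.7, 0.1, 0.8])  # illustrative anomaly scores
+ ground_truth = np.array([1, 0, 1, 0, 0])      # illustrative labels
+
+ # Precision@K: K = number of true anomalies; flag the K highest-scoring samples
+ k = int(ground_truth.sum())
+ predictions = np.zeros_like(ground_truth)
+ predictions[np.argsort(scores)[-k:]] = 1
+ print(metrics.precision_score(ground_truth, predictions, zero_division=0))
+
+ # Best F1: sweep every threshold on the precision-recall curve
+ precision, recall, _ = metrics.precision_recall_curve(ground_truth, scores)
+ print(np.max(2 * precision * recall / (precision + recall + 1e-10)))
+ ```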
+
+ ## Output Structure
+
+ ```
+ experiments/exp_name/
+ +-- results.json    (input)
+ +-- training.csv    (optional spot file)
+ +-- validation.csv  (optional spot file)
+ +-- eval/
+     +-- evaluation.json         (computed metrics)
+     +-- roc_curve.png
+     +-- pr_curve.png
+     +-- score_distribution.png
+     +-- training_curves.png     (one <name>_curves.png per spot file)
+     +-- validation_curves.png
+ ```
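+
+ Spot files are plain CSVs; every column except `timestamp` and `epoch` is drawn as a curve (see `plot_spot_curves` in `plots.py`). A minimal hypothetical `training.csv`:
+
+ ```
+ epoch,loss,auc
+ 1,0.693,0.51
+ 2,0.541,0.68
+ 3,0.472,0.74
+ ```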
+
+ ### evaluation.json Format
+
+ ```json
+ {
+   "experiment_name": "exp__generaldyg__btc_alpha__20251211_120000",
+   "result_type": "EDGE_STREAM_ANOMALY_SCORES",
+   "metrics": {
+     "auc_roc": 0.9234,
+     "auc_pr": 0.8765,
+     "precision_at_k": 0.8500,
+     "recall_at_k": 0.8500,
+     "f1_at_k": 0.8500,
+     "best_f1": 0.8723,
+     "best_f1_threshold": 0.5432,
+     "num_anomalies": 345,
+     "num_samples": 3783,
+     "anomaly_ratio": 0.0912
+   },
+   "metadata": {},
+   "plots": {
+     "roc_curve": "roc_curve.png",
+     "pr_curve": "pr_curve.png",
+     "score_distribution": "score_distribution.png",
+     "training_curves": "training_curves.png",
+     "validation_curves": "validation_curves.png"
+   },
+   "spot_files": ["training", "validation"]
+ }
+ ```
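+
+ A minimal sketch for reading the metrics back out (standard library only; the path assumes the layout above, and `evaluate()` also returns it directly):
+
+ ```python
+ import json
+ from pathlib import Path
+
+ evaluation = json.loads(Path("experiments/exp_name/eval/evaluation.json").read_text())
+ print(evaluation["result_type"], evaluation["metrics"]["auc_roc"])
+ ```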
+
+ ## Adding Custom Metrics
+
+ ```python
+ from graflag_evaluator.metrics import MetricCalculator
+
+ def compute_custom_metric(scores, ground_truth, **kwargs):
+     return {"custom_metric": 0.123}
+
+ MetricCalculator.register_metric(
+     "EDGE_STREAM_ANOMALY_SCORES",
+     compute_custom_metric
+ )
+ ```
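+
+ Registered functions run automatically on the next `evaluate()` call for that result type; an exception inside one metric function is logged and skipped rather than aborting the evaluation. To check what is registered for a type, the package's `get_metrics_for_type` helper can be used:
+
+ ```python
+ from graflag_evaluator import get_metrics_for_type
+
+ # Returns the names of the metric functions registered for this result type
+ print(get_metrics_for_type("EDGE_STREAM_ANOMALY_SCORES"))
+ ```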
+
+ ## Architecture
+
+ ```
+ graflag_evaluator/
+ +-- __init__.py          Package exports
+ +-- evaluator.py         Main orchestrator
+ +-- metrics.py           Metric calculators with registry
+ +-- plots.py             Plot generation utilities
+ +-- run_evaluation.py    Docker container entry point
+ ```
+
+ ## Troubleshooting
+
+ **"results.json not found"** -- The experiment has not completed yet, or it failed before writing results.
+
+ **"No ground_truth found"** -- results.json must include a ground_truth array for evaluation.
+
+ **"Only one class present"** -- The dataset contains no anomalies, or only anomalies. Check data preparation.
@@ -0,0 +1,8 @@
+ """GraFlag Evaluator - Modular evaluation system for graph anomaly detection."""
+
+ from .metrics import MetricCalculator, get_metrics_for_type
+ from .evaluator import Evaluator
+
+ __all__ = ["MetricCalculator", "get_metrics_for_type", "Evaluator"]
+
+ __version__ = "1.0.0"
@@ -0,0 +1,215 @@
+ """Main evaluator orchestrator."""
+
+ import json
+ import numpy as np
+ from pathlib import Path
+ from typing import Dict, Any, Optional
+ import logging
+
+ from .metrics import MetricCalculator
+ from .plots import PlotGenerator
+
+ logger = logging.getLogger(__name__)
+
+
+ class Evaluator:
+     """
+     Main evaluation orchestrator for GraFlag experiments.
+
+     Automatically:
+     1. Loads results.json from experiment directory
+     2. Detects result type and loads appropriate data
+     3. Computes all relevant metrics
+     4. Generates evaluation plots (ROC, PR, spot curves)
+     5. Saves evaluation.json with all metrics and metadata
+     """
+
+     def __init__(self, experiment_path: Path):
+         """
+         Initialize evaluator for an experiment.
+
+         Args:
+             experiment_path: Path to experiment directory
+         """
+         self.experiment_path = Path(experiment_path)
+         self.results_path = self.experiment_path / "results.json"
+         self.eval_dir = self.experiment_path / "eval"
+
+         if not self.results_path.exists():
+             raise FileNotFoundError(f"results.json not found in {self.experiment_path}")
+
+         # Create eval directory
+         self.eval_dir.mkdir(exist_ok=True)
+
+         # Load results
+         with open(self.results_path, 'r') as f:
+             self.results = json.load(f)
+
+         self.result_type = self.results.get("result_type")
+         if not self.result_type:
+             raise ValueError("result_type not found in results.json")
+
+         logger.info(f"[INFO] Evaluating experiment: {self.experiment_path.name}")
+         logger.info(f"   Result type: {self.result_type}")
+
+     def _load_scores_and_ground_truth(self) -> tuple:
+         """Load scores and ground truth from results."""
+         scores_raw = self.results.get("scores", [])
+         ground_truth_raw = self.results.get("ground_truth", [])
+
+         if len(scores_raw) == 0:
+             raise ValueError("No scores found in results.json")
+         if len(ground_truth_raw) == 0:
+             raise ValueError("No ground_truth found in results.json")
+
+         # Handle ragged arrays (e.g., TEMPORAL_EDGE_ANOMALY_SCORES where each
+         # snapshot has a different number of edges). Use dtype=object for ragged.
+         try:
+             scores = np.array(scores_raw)
+         except ValueError:
+             # Ragged array - use object dtype
+             scores = np.array(scores_raw, dtype=object)
+
+         try:
+             ground_truth = np.array(ground_truth_raw)
+         except ValueError:
+             # Ragged array - use object dtype
+             ground_truth = np.array(ground_truth_raw, dtype=object)
+
+         logger.info(f"   Scores shape: {scores.shape}, dtype: {scores.dtype}")
+         logger.info(f"   Ground truth shape: {ground_truth.shape}, dtype: {ground_truth.dtype}")
+
+         return scores, ground_truth
+
+     def _find_spot_files(self) -> Dict[str, Path]:
+         """Find all spot CSV files in experiment directory."""
+         spot_files = {}
+         for csv_file in self.experiment_path.glob("*.csv"):
+             metric_key = csv_file.stem  # filename without extension
+             spot_files[metric_key] = csv_file
+
+         if spot_files:
+             logger.info(f"   Found {len(spot_files)} spot files: {list(spot_files.keys())}")
+
+         return spot_files
+
+     def compute_metrics(self) -> Dict[str, Any]:
+         """
+         Compute all metrics for the experiment.
+
+         Returns:
+             Dictionary of computed metrics
+         """
+         scores, ground_truth = self._load_scores_and_ground_truth()
+
+         # Get additional data (timestamps, edges, etc.)
+         kwargs = {
+             "timestamps": self.results.get("timestamps"),
+             "edges": self.results.get("edges"),
+             "node_ids": self.results.get("node_ids"),
+             "graph_ids": self.results.get("graph_ids"),
+         }
+
+         # Compute metrics
+         logger.info("[INFO] Computing metrics...")
+         metrics = MetricCalculator.calculate_metrics(
+             self.result_type, scores, ground_truth, **kwargs
+         )
+
+         logger.info(f"[OK] Computed {len(metrics)} metrics")
+         return metrics
+
+     def generate_plots(self) -> list:
+         """Generate all evaluation plots.
+
+         Returns:
+             List of generated spot curve plot filenames
+         """
+         logger.info("[INFO] Generating plots...")
+
+         scores, ground_truth = self._load_scores_and_ground_truth()
+
+         # ROC curve
+         roc_path = self.eval_dir / "roc_curve.png"
+         PlotGenerator.plot_roc_curve(scores, ground_truth, roc_path,
+                                      title=f"ROC Curve - {self.experiment_path.name}")
+
+         # PR curve
+         pr_path = self.eval_dir / "pr_curve.png"
+         PlotGenerator.plot_pr_curve(scores, ground_truth, pr_path,
+                                     title=f"PR Curve - {self.experiment_path.name}")
+
+         # Score distribution
+         dist_path = self.eval_dir / "score_distribution.png"
+         PlotGenerator.plot_score_distribution(scores, ground_truth, dist_path,
+                                               title=f"Score Distribution - {self.experiment_path.name}")
+
+         # Spot curves from spot files (generates separate files)
+         spot_files = self._find_spot_files()
+         spot_plot_files = []
+         if spot_files:
+             spot_plot_files = PlotGenerator.plot_spot_curves(
+                 spot_files, self.eval_dir,
+                 title=f"Spot Curves - {self.experiment_path.name}"
+             )
+
+         logger.info(f"[OK] Plots saved to {self.eval_dir}")
+         return spot_plot_files
+
+     def evaluate(self) -> Path:
+         """
+         Run full evaluation: compute metrics and generate plots.
+
+         Returns:
+             Path to evaluation.json
+         """
+         # Compute metrics
+         computed_metrics = self.compute_metrics()
+
+         # Generate plots (returns list of spot curve plot filenames)
+         spot_plot_files = self.generate_plots()
+
+         # Build evaluation results
+         evaluation = {
+             "experiment_name": self.experiment_path.name,
+             "result_type": self.result_type,
+             "metrics": computed_metrics,
+             "metadata": self.results.get("metadata", {}),
+             "plots": {
+                 "roc_curve": "roc_curve.png",
+                 "pr_curve": "pr_curve.png",
+                 "score_distribution": "score_distribution.png",
+             },
+         }
+
+         # Add spot curve plots if available
+         spot_files = self._find_spot_files()
+         if spot_files:
+             evaluation["spot_files"] = list(spot_files.keys())
+             # Add each spot curve plot
+             for plot_file in spot_plot_files:
+                 plot_key = plot_file.replace('.png', '')
+                 evaluation["plots"][plot_key] = plot_file
+
+         # Save evaluation.json
+         eval_json_path = self.eval_dir / "evaluation.json"
+         with open(eval_json_path, 'w') as f:
+             json.dump(evaluation, f, indent=2)
+
+         logger.info("[OK] Evaluation complete!")
+         logger.info(f"   Results: {eval_json_path}")
+         logger.info(f"   Plots: {self.eval_dir}")
+
+         # Print summary
+         print("\n" + "=" * 60)
+         print(f"[INFO] Evaluation Summary: {self.experiment_path.name}")
+         print("=" * 60)
+         print(f"Result Type: {self.result_type}")
+         print("\nKey Metrics:")
+         for key, value in computed_metrics.items():
+             # None never passes the isinstance check, so no separate None test is needed
+             if isinstance(value, (int, float)):
+                 print(f"  {key}: {value}")
+         print(f"\nPlots saved to: {self.eval_dir}")
+         print("=" * 60 + "\n")
+
+         return eval_json_path
@@ -0,0 +1,14 @@
+ Metadata-Version: 2.4
+ Name: graflag_evaluator
+ Version: 1.0.0
+ Summary: Evaluation framework for graph anomaly detection methods in GraFlag
+ Author: GraFlag Team
+ Requires-Python: >=3.7
+ Requires-Dist: numpy>=1.21.0
+ Requires-Dist: scikit-learn>=1.0.0
+ Requires-Dist: matplotlib>=3.5.0
+ Requires-Dist: pandas>=1.3.0
+ Dynamic: author
+ Dynamic: requires-dist
+ Dynamic: requires-python
+ Dynamic: summary
@@ -0,0 +1,17 @@
+ README.md
+ __init__.py
+ evaluator.py
+ metrics.py
+ plots.py
+ run_evaluation.py
+ setup.py
+ ./__init__.py
+ ./evaluator.py
+ ./metrics.py
+ ./plots.py
+ ./run_evaluation.py
+ graflag_evaluator.egg-info/PKG-INFO
+ graflag_evaluator.egg-info/SOURCES.txt
+ graflag_evaluator.egg-info/dependency_links.txt
+ graflag_evaluator.egg-info/requires.txt
+ graflag_evaluator.egg-info/top_level.txt
@@ -0,0 +1,4 @@
+ numpy>=1.21.0
+ scikit-learn>=1.0.0
+ matplotlib>=3.5.0
+ pandas>=1.3.0
@@ -0,0 +1 @@
+ graflag_evaluator
@@ -0,0 +1,241 @@
+ """Metric calculators for different result types."""
+
+ import numpy as np
+ from sklearn import metrics
+ from typing import Dict, List, Any, Callable
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+
+ class MetricCalculator:
+     """
+     Base class for metric calculation.
+
+     Supports plugin-based architecture for adding new metrics.
+     """
+
+     # Registry of metric functions by result type
+     _METRIC_REGISTRY: Dict[str, List[Callable]] = {}
+
+     @classmethod
+     def register_metric(cls, result_type: str, metric_func: Callable):
+         """
+         Register a new metric function for a result type.
+
+         Args:
+             result_type: Result type (e.g., "EDGE_STREAM_ANOMALY_SCORES")
+             metric_func: Function that takes (scores, ground_truth, **kwargs)
+                 and returns Dict[str, float]
+         """
+         if result_type not in cls._METRIC_REGISTRY:
+             cls._METRIC_REGISTRY[result_type] = []
+         cls._METRIC_REGISTRY[result_type].append(metric_func)
+         logger.debug(f"Registered metric {metric_func.__name__} for {result_type}")
+
+     @classmethod
+     def calculate_metrics(cls, result_type: str, scores: np.ndarray,
+                           ground_truth: np.ndarray, **kwargs) -> Dict[str, Any]:
+         """
+         Calculate all registered metrics for a result type.
+
+         Args:
+             result_type: Type of anomaly detection result
+             scores: Anomaly scores
+             ground_truth: Ground truth labels
+             **kwargs: Additional parameters (timestamps, edges, etc.)
+
+         Returns:
+             Dictionary of computed metrics
+         """
+         if result_type not in cls._METRIC_REGISTRY:
+             logger.warning(f"No metrics registered for {result_type}")
+             return {}
+
+         all_metrics = {}
+         for metric_func in cls._METRIC_REGISTRY[result_type]:
+             try:
+                 result = metric_func(scores, ground_truth, **kwargs)
+                 all_metrics.update(result)
+             except Exception as e:
+                 logger.error(f"Error in {metric_func.__name__}: {e}")
+
+         return all_metrics
+
+
+ # ============================================================================
+ # Standard Metrics for Binary Anomaly Detection
+ # ============================================================================
+
+ def compute_classification_metrics(scores: np.ndarray, ground_truth: np.ndarray,
+                                    **kwargs) -> Dict[str, float]:
+     """
+     Compute standard classification metrics (works for all types).
+
+     Metrics:
+     - AUC-ROC: Area under ROC curve
+     - AUC-PR: Area under Precision-Recall curve
+     - Precision@K: Precision in top K predictions
+     - Recall@K: Recall in top K predictions
+     - F1@K: F1 score in top K predictions
+     - Best F1: Best F1 score across all thresholds
+     """
+     # Handle nested lists (e.g., TEMPORAL_EDGE_ANOMALY_SCORES where each snapshot
+     # has a different number of edges). np.array() creates an object array for ragged lists.
+     if scores.dtype == object or (scores.ndim == 1 and isinstance(scores[0], (list, np.ndarray))):
+         # Flatten nested structure
+         scores_flat = np.concatenate([np.asarray(s).flatten() for s in scores])
+         gt_flat = np.concatenate([np.asarray(g).flatten() for g in ground_truth])
+     else:
+         scores_flat = scores.flatten()
+         gt_flat = ground_truth.flatten()
+
+     # Remove sentinel values marking invalid scores (-2, -1) if present
+     if np.max(scores_flat) <= 1:
+         valid_mask = (scores_flat >= 0) & (scores_flat <= 1)
+     else:
+         valid_mask = scores_flat > -2
+     scores_valid = scores_flat[valid_mask]
+     gt_valid = gt_flat[valid_mask]
+
+     if len(np.unique(gt_valid)) < 2:
+         logger.warning("Ground truth has only one class, skipping some metrics")
+         return {"auc_roc": None, "auc_pr": None}
+
+     # AUC-ROC
+     auc_roc = metrics.roc_auc_score(gt_valid, scores_valid)
+
+     # AUC-PR
+     precision, recall, thresholds = metrics.precision_recall_curve(gt_valid, scores_valid)
+     auc_pr = metrics.auc(recall, precision)
+
+     # Precision/Recall/F1 at K (K = number of anomalies)
+     k = int(np.sum(gt_valid))
+     top_k_indices = np.argsort(scores_valid)[-k:]
+     predictions_at_k = np.zeros_like(gt_valid)
+     predictions_at_k[top_k_indices] = 1
+
+     precision_at_k = metrics.precision_score(gt_valid, predictions_at_k, zero_division=0)
+     recall_at_k = metrics.recall_score(gt_valid, predictions_at_k, zero_division=0)
+     f1_at_k = metrics.f1_score(gt_valid, predictions_at_k, zero_division=0)
+
+     # Best F1 across all thresholds. precision/recall have one more entry than
+     # thresholds, so guard against the final (recall=0) point having no threshold.
+     f1_scores = 2 * (precision * recall) / (precision + recall + 1e-10)
+     best_idx = int(np.argmax(f1_scores))
+     best_f1 = f1_scores[best_idx]
+     best_f1_threshold = float(thresholds[best_idx]) if best_idx < len(thresholds) else None
+
+     return {
+         "auc_roc": round(float(auc_roc), 4),
+         "auc_pr": round(float(auc_pr), 4),
+         "precision_at_k": round(float(precision_at_k), 4),
+         "recall_at_k": round(float(recall_at_k), 4),
+         "f1_at_k": round(float(f1_at_k), 4),
+         "best_f1": round(float(best_f1), 4),
+         "best_f1_threshold": round(best_f1_threshold, 4) if best_f1_threshold is not None else None,
+         "num_anomalies": int(k),
+         "num_samples": int(len(gt_valid)),
+         "anomaly_ratio": round(float(k / len(gt_valid)), 4),
+     }
+
+
+ def compute_temporal_metrics(scores: np.ndarray, ground_truth: np.ndarray,
+                              timestamps: List[int] = None, **kwargs) -> Dict[str, float]:
+     """
+     Compute temporal-specific metrics.
+
+     Metrics:
+     - Early detection rate: How early anomalies are detected
+     - Temporal consistency: How consistent scores are over time
+     """
+     if timestamps is None:
+         return {}
+
+     # Early detection: average time between first high score and actual anomaly
+     # (This is a placeholder - implement based on your specific needs)
+
+     return {
+         "temporal_span": int(max(timestamps) - min(timestamps)) if timestamps else 0,
+         "num_timestamps": len(set(timestamps)) if timestamps else 0,
+     }
+
+
+ def compute_edge_metrics(scores: np.ndarray, ground_truth: np.ndarray,
+                          edges: List[List[int]] = None, **kwargs) -> Dict[str, float]:
+     """
+     Compute edge-specific metrics.
+
+     Metrics:
+     - Number of unique edges
+     - Edge degree distribution stats
+     """
+     if edges is None:
+         return {}
+
+     # Count unique edges
+     unique_edges = len(set(tuple(e) for e in edges))
+
+     # Node degree stats (how many times each node appears)
+     nodes = [n for edge in edges for n in edge]
+     unique_nodes = len(set(nodes))
+
+     return {
+         "num_unique_edges": int(unique_edges),
+         "num_unique_nodes": int(unique_nodes),
+         "total_edge_occurrences": int(len(edges)),
+     }
+
+
+ # ============================================================================
+ # Register Default Metrics
+ # ============================================================================
+
+ # Register classification metrics for all result types
+ for result_type in [
+     "NODE_ANOMALY_SCORES",
+     "EDGE_ANOMALY_SCORES",
+     "GRAPH_ANOMALY_SCORES",
+     "TEMPORAL_NODE_ANOMALY_SCORES",
+     "TEMPORAL_EDGE_ANOMALY_SCORES",
+     "TEMPORAL_GRAPH_ANOMALY_SCORES",
+     "NODE_STREAM_ANOMALY_SCORES",
+     "EDGE_STREAM_ANOMALY_SCORES",
+     "GRAPH_STREAM_ANOMALY_SCORES",
+ ]:
+     MetricCalculator.register_metric(result_type, compute_classification_metrics)
+
+ # Register temporal metrics for temporal and stream types
+ for result_type in [
+     "TEMPORAL_NODE_ANOMALY_SCORES",
+     "TEMPORAL_EDGE_ANOMALY_SCORES",
+     "TEMPORAL_GRAPH_ANOMALY_SCORES",
+     "NODE_STREAM_ANOMALY_SCORES",
+     "EDGE_STREAM_ANOMALY_SCORES",
+     "GRAPH_STREAM_ANOMALY_SCORES",
+ ]:
+     MetricCalculator.register_metric(result_type, compute_temporal_metrics)
+
+ # Register edge metrics for edge types
+ for result_type in [
+     "EDGE_ANOMALY_SCORES",
+     "TEMPORAL_EDGE_ANOMALY_SCORES",
+     "EDGE_STREAM_ANOMALY_SCORES",
+ ]:
+     MetricCalculator.register_metric(result_type, compute_edge_metrics)
+
+
+ def get_metrics_for_type(result_type: str) -> List[str]:
+     """
+     Get the names of the metric functions registered for a result type.
+
+     Args:
+         result_type: Result type string
+
+     Returns:
+         List of registered metric function names
+     """
+     if result_type not in MetricCalculator._METRIC_REGISTRY:
+         return []
+
+     # Extract function names from the registered metric functions
+     metric_names = []
+     for func in MetricCalculator._METRIC_REGISTRY[result_type]:
+         metric_names.append(func.__name__)
+
+     return metric_names
@@ -0,0 +1,198 @@
+ """Plot generation utilities for evaluation."""
+
+ import numpy as np
+ import matplotlib
+ matplotlib.use('Agg')  # Non-interactive backend
+ import matplotlib.pyplot as plt
+ from sklearn import metrics
+ from pathlib import Path
+ from typing import Dict, List, Optional
+ import pandas as pd
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+
+ def _flatten_ragged(arr: np.ndarray) -> np.ndarray:
+     """Flatten array, handling ragged/object arrays properly."""
+     if arr.dtype == object or (arr.ndim == 1 and len(arr) > 0 and isinstance(arr[0], (list, np.ndarray))):
+         # Ragged array - concatenate all elements
+         return np.concatenate([np.asarray(x).flatten() for x in arr])
+     return arr.flatten()
+
+
+ class PlotGenerator:
+     """Generate evaluation plots."""
+
+     @staticmethod
+     def plot_roc_curve(scores: np.ndarray, ground_truth: np.ndarray,
+                        output_path: Path, title: str = "ROC Curve"):
+         """
+         Generate ROC curve plot.
+
+         Args:
+             scores: Anomaly scores
+             ground_truth: Ground truth labels
+             output_path: Path to save plot
+             title: Plot title
+         """
+         scores_flat = _flatten_ragged(scores)
+         gt_flat = _flatten_ragged(ground_truth)
+
+         # Remove invalid scores
+         valid_mask = scores_flat > -2
+         scores_valid = scores_flat[valid_mask]
+         gt_valid = gt_flat[valid_mask]
+
+         if len(np.unique(gt_valid)) < 2:
+             logger.warning("Cannot plot ROC: only one class present")
+             return
+
+         fpr, tpr, thresholds = metrics.roc_curve(gt_valid, scores_valid)
+         auc_score = metrics.auc(fpr, tpr)
+
+         plt.figure(figsize=(8, 6))
+         plt.plot(fpr, tpr, label=f'AUC = {auc_score:.4f}', linewidth=2)
+         plt.plot([0, 1], [0, 1], 'k--', label='Random', linewidth=1)
+         plt.xlabel('False Positive Rate', fontsize=12)
+         plt.ylabel('True Positive Rate', fontsize=12)
+         plt.title(title, fontsize=14)
+         plt.legend(fontsize=10)
+         plt.grid(alpha=0.3)
+         plt.tight_layout()
+         plt.savefig(output_path, dpi=150)
+         plt.close()
+
+         logger.info(f"[OK] ROC curve saved to {output_path}")
+
+     @staticmethod
+     def plot_pr_curve(scores: np.ndarray, ground_truth: np.ndarray,
+                       output_path: Path, title: str = "Precision-Recall Curve"):
+         """
+         Generate Precision-Recall curve plot.
+
+         Args:
+             scores: Anomaly scores
+             ground_truth: Ground truth labels
+             output_path: Path to save plot
+             title: Plot title
+         """
+         scores_flat = _flatten_ragged(scores)
+         gt_flat = _flatten_ragged(ground_truth)
+
+         valid_mask = scores_flat > -2
+         scores_valid = scores_flat[valid_mask]
+         gt_valid = gt_flat[valid_mask]
+
+         if len(np.unique(gt_valid)) < 2:
+             logger.warning("Cannot plot PR: only one class present")
+             return
+
+         precision, recall, thresholds = metrics.precision_recall_curve(gt_valid, scores_valid)
+         auc_score = metrics.auc(recall, precision)
+
+         plt.figure(figsize=(8, 6))
+         plt.plot(recall, precision, label=f'AUC-PR = {auc_score:.4f}', linewidth=2)
+         plt.xlabel('Recall', fontsize=12)
+         plt.ylabel('Precision', fontsize=12)
+         plt.title(title, fontsize=14)
+         plt.legend(fontsize=10)
+         plt.grid(alpha=0.3)
+         plt.tight_layout()
+         plt.savefig(output_path, dpi=150)
+         plt.close()
+
+         logger.info(f"[OK] PR curve saved to {output_path}")
+
+     @staticmethod
+     def plot_score_distribution(scores: np.ndarray, ground_truth: np.ndarray,
+                                 output_path: Path, title: str = "Score Distribution"):
+         """
+         Generate score distribution plot (histogram for anomalies vs normal).
+
+         Args:
+             scores: Anomaly scores
+             ground_truth: Ground truth labels
+             output_path: Path to save plot
+             title: Plot title
+         """
+         scores_flat = _flatten_ragged(scores)
+         gt_flat = _flatten_ragged(ground_truth)
+
+         valid_mask = scores_flat > -2
+         scores_valid = scores_flat[valid_mask]
+         gt_valid = gt_flat[valid_mask]
+
+         normal_scores = scores_valid[gt_valid == 0]
+         anomaly_scores = scores_valid[gt_valid == 1]
+
+         plt.figure(figsize=(8, 6))
+         plt.hist(normal_scores, bins=50, alpha=0.5, label='Normal', color='blue')
+         plt.hist(anomaly_scores, bins=50, alpha=0.5, label='Anomaly', color='red')
+         plt.xlabel('Anomaly Score', fontsize=12)
+         plt.ylabel('Frequency', fontsize=12)
+         plt.title(title, fontsize=14)
+         plt.legend(fontsize=10)
+         plt.grid(alpha=0.3)
+         plt.tight_layout()
+         plt.savefig(output_path, dpi=150)
+         plt.close()
+
+         logger.info(f"[OK] Score distribution saved to {output_path}")
+
+     @staticmethod
+     def plot_spot_curves(spot_files: Dict[str, Path], output_dir: Path,
+                          title: str = "Spot Curves") -> List[str]:
+         """
+         Generate separate spot curve plots from spot CSV files.
+
+         Args:
+             spot_files: Dictionary mapping metric_key to CSV path
+             output_dir: Directory to save plots
+             title: Plot title prefix
+
+         Returns:
+             List of generated plot filenames
+         """
+         if not spot_files:
+             logger.warning("No spot files to plot")
+             return []
+
+         generated_plots = []
+
+         for metric_key, csv_path in spot_files.items():
+             try:
+                 df = pd.read_csv(csv_path)
+
+                 # Get columns to plot (exclude timestamp and epoch)
+                 plot_cols = [col for col in df.columns if col not in ('timestamp', 'epoch')]
+
+                 if not plot_cols:
+                     continue
+
+                 # Create a separate plot for this spot file
+                 plt.figure(figsize=(10, 6))
+
+                 for col in plot_cols:
+                     plt.plot(df.index, df[col], label=col, marker='o', markersize=3, linewidth=1.5)
+
+                 plt.xlabel('Step/Epoch', fontsize=12)
+                 plt.ylabel('Value', fontsize=12)
+                 plt.title(f"{metric_key.replace('_', ' ').title()} Curves", fontsize=14)
+                 plt.legend(fontsize=10, loc='best')
+                 plt.grid(alpha=0.3)
+                 plt.tight_layout()
+
+                 # Save with metric_key name
+                 output_filename = f"{metric_key}_curves.png"
+                 output_path = output_dir / output_filename
+                 plt.savefig(output_path, dpi=150)
+                 plt.close()
+
+                 generated_plots.append(output_filename)
+                 logger.info(f"[OK] {metric_key} curves saved to {output_path}")
+
+             except Exception as e:
+                 logger.error(f"Error plotting {metric_key}: {e}")
+
+         return generated_plots
@@ -0,0 +1,36 @@
+ #!/usr/bin/env python3
+ """Standalone script to run evaluation on an experiment."""
+
+ import sys
+ from pathlib import Path
+
+ # Add graflag_evaluator to path
+ sys.path.insert(0, str(Path(__file__).parent.parent))
+
+ from graflag_evaluator import Evaluator
+
+ def main():
+     if len(sys.argv) < 2:
+         print("Usage: run_evaluation.py <experiment_directory>")
+         print("Example: run_evaluation.py /shared/experiments/exp_name")
+         sys.exit(1)
+
+     exp_dir = Path(sys.argv[1])
+
+     if not exp_dir.exists():
+         print(f"Error: Experiment directory not found: {exp_dir}")
+         sys.exit(1)
+
+     if not (exp_dir / "results.json").exists():
+         print(f"Error: results.json not found in {exp_dir}")
+         sys.exit(1)
+
+     # Run evaluation
+     print(f"[INFO] Loading experiment from: {exp_dir}")
+     evaluator = Evaluator(exp_dir)
+     eval_path = evaluator.evaluate()
+
+     print(f"\n[OK] Evaluation complete! Results saved to: {eval_path}")
+
+ if __name__ == "__main__":
+     main()
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+
@@ -0,0 +1,19 @@
+ """Setup script for graflag_evaluator package."""
+
+ from setuptools import setup
+
+ setup(
+     name="graflag_evaluator",
+     version="1.0.0",
+     description="Evaluation framework for graph anomaly detection methods in GraFlag",
+     author="GraFlag Team",
+     packages=["graflag_evaluator"],
+     package_dir={"graflag_evaluator": "."},
+     install_requires=[
+         "numpy>=1.21.0",
+         "scikit-learn>=1.0.0",
+         "matplotlib>=3.5.0",
+         "pandas>=1.3.0",
+     ],
+     python_requires=">=3.7",
+ )