graflag-bond 1.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graflag_bond-1.0.0/PKG-INFO +15 -0
- graflag_bond-1.0.0/__init__.py +15 -0
- graflag_bond-1.0.0/detectors.py +88 -0
- graflag_bond-1.0.0/graflag_bond.egg-info/PKG-INFO +15 -0
- graflag_bond-1.0.0/graflag_bond.egg-info/SOURCES.txt +14 -0
- graflag_bond-1.0.0/graflag_bond.egg-info/dependency_links.txt +1 -0
- graflag_bond-1.0.0/graflag_bond.egg-info/requires.txt +5 -0
- graflag_bond-1.0.0/graflag_bond.egg-info/top_level.txt +1 -0
- graflag_bond-1.0.0/setup.cfg +4 -0
- graflag_bond-1.0.0/setup.py +20 -0
- graflag_bond-1.0.0/train.py +233 -0
- graflag_bond-1.0.0/utils.py +186 -0
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: graflag_bond
|
|
3
|
+
Version: 1.0.0
|
|
4
|
+
Summary: Universal PyGOD detector wrapper for GraFlag BOND methods
|
|
5
|
+
Author: GraFlag Team
|
|
6
|
+
Requires-Python: >=3.7
|
|
7
|
+
Requires-Dist: torch>=2.0.0
|
|
8
|
+
Requires-Dist: torch-geometric>=2.3.0
|
|
9
|
+
Requires-Dist: pygod>=1.1.0
|
|
10
|
+
Requires-Dist: numpy>=1.24.0
|
|
11
|
+
Requires-Dist: scikit-learn>=1.3.0
|
|
12
|
+
Dynamic: author
|
|
13
|
+
Dynamic: requires-dist
|
|
14
|
+
Dynamic: requires-python
|
|
15
|
+
Dynamic: summary
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""
|
|
2
|
+
GraFlag Bond - Generic PyGOD Detector Wrapper
|
|
3
|
+
|
|
4
|
+
This library provides a unified interface for running PyGOD anomaly detection
|
|
5
|
+
methods through the GraFlag framework.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from .detectors import BondDetector
|
|
9
|
+
from .utils import get_all_parameters
|
|
10
|
+
|
|
11
|
+
__version__ = "1.0.0"  # package version; matches the version declared in setup.py

# Public names exported by ``from graflag_bond import *`` (imported above).
__all__ = ["BondDetector", "get_all_parameters"]
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
"""
|
|
2
|
+
PyGOD Detector Enumeration
|
|
3
|
+
|
|
4
|
+
Dynamically discovers and maps all PyGOD detector classes.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import inspect
|
|
8
|
+
import pygod.detector
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class BondDetector:
    """Lazily built registry of PyGOD detector classes, keyed by lowercase class name."""

    # Cache of {lowercase class name: detector class}; None until first use.
    _detectors = None

    @classmethod
    def _load_detectors(cls):
        """Populate the registry from ``pygod.detector`` on first use.

        Only concrete (non-abstract) classes whose ``__module__`` lies under
        ``pygod.detector`` are registered. Once the cache is filled, further
        calls return immediately.
        """
        if cls._detectors is not None:
            return

        detectors = {}
        for name, obj in inspect.getmembers(pygod.detector, inspect.isclass):
            # Skip classes that pygod.detector merely re-exports from elsewhere.
            if not obj.__module__.startswith('pygod.detector'):
                continue
            # Abstract base classes cannot be instantiated, so listing them as
            # "supported" methods would only lead to a later TypeError.
            if inspect.isabstract(obj):
                continue
            detectors[name.lower()] = obj

        # Publish only after the scan completes, so a failure part-way
        # through does not leave a partially filled registry cached.
        cls._detectors = detectors

    @classmethod
    def from_method_name(cls, method_name: str):
        """
        Resolve a method name to a registered detector name.

        Args:
            method_name: Method name (e.g., 'bond_dominant', 'dominant', 'DOMINANT').
                Surrounding whitespace and a leading 'bond_' prefix (any case)
                are ignored.

        Returns:
            Detector class name (lowercase), usable with get_detector_class.

        Raises:
            ValueError: If method name is not supported
        """
        cls._load_detectors()

        name = method_name.strip().lower()
        # Strip only a *leading* "bond_"; str.replace would also delete the
        # substring from the middle of a name.
        prefix = "bond_"
        if name.startswith(prefix):
            name = name[len(prefix):]

        if name not in cls._detectors:
            supported = ", ".join(sorted(cls._detectors.keys()))
            raise ValueError(f"Unsupported detector: {name}. Supported: {supported}")

        return name

    @classmethod
    def get_detector_class(cls, detector_name: str):
        """
        Get the PyGOD detector class by name.

        Args:
            detector_name: Detector name (e.g., 'dominant', 'adone'); matched
                case-insensitively, ignoring surrounding whitespace.

        Returns:
            PyGOD detector class

        Raises:
            ValueError: If detector name is not found
        """
        cls._load_detectors()

        name = detector_name.strip().lower()
        if name not in cls._detectors:
            supported = ", ".join(sorted(cls._detectors.keys()))
            raise ValueError(f"Detector not found: {name}. Available: {supported}")

        return cls._detectors[name]

    @classmethod
    def list_detectors(cls):
        """
        List all available detector names.

        Returns:
            Sorted list of detector names (lowercase)
        """
        cls._load_detectors()
        return sorted(cls._detectors.keys())
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: graflag_bond
|
|
3
|
+
Version: 1.0.0
|
|
4
|
+
Summary: Universal PyGOD detector wrapper for GraFlag BOND methods
|
|
5
|
+
Author: GraFlag Team
|
|
6
|
+
Requires-Python: >=3.7
|
|
7
|
+
Requires-Dist: torch>=2.0.0
|
|
8
|
+
Requires-Dist: torch-geometric>=2.3.0
|
|
9
|
+
Requires-Dist: pygod>=1.1.0
|
|
10
|
+
Requires-Dist: numpy>=1.24.0
|
|
11
|
+
Requires-Dist: scikit-learn>=1.3.0
|
|
12
|
+
Dynamic: author
|
|
13
|
+
Dynamic: requires-dist
|
|
14
|
+
Dynamic: requires-python
|
|
15
|
+
Dynamic: summary
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
__init__.py
|
|
2
|
+
detectors.py
|
|
3
|
+
setup.py
|
|
4
|
+
train.py
|
|
5
|
+
utils.py
|
|
6
|
+
./__init__.py
|
|
7
|
+
./detectors.py
|
|
8
|
+
./train.py
|
|
9
|
+
./utils.py
|
|
10
|
+
graflag_bond.egg-info/PKG-INFO
|
|
11
|
+
graflag_bond.egg-info/SOURCES.txt
|
|
12
|
+
graflag_bond.egg-info/dependency_links.txt
|
|
13
|
+
graflag_bond.egg-info/requires.txt
|
|
14
|
+
graflag_bond.egg-info/top_level.txt
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
graflag_bond
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""Setup script for graflag_bond package."""
|
|
2
|
+
|
|
3
|
+
from setuptools import setup
|
|
4
|
+
|
|
5
|
+
setup(
    name="graflag_bond",
    version="1.0.0",
    description="Universal PyGOD detector wrapper for GraFlag BOND methods",
    author="GraFlag Team",
    packages=["graflag_bond"],
    # The package modules live at the project root rather than in a
    # graflag_bond/ subdirectory, so the package is mapped onto ".".
    package_dir={"graflag_bond": "."},
    install_requires=[
        "torch>=2.0.0",
        "torch-geometric>=2.3.0",
        "pygod>=1.1.0",
        "numpy>=1.24.0",
        "scikit-learn>=1.3.0",
        # train.py imports psutil for memory tracking.
        "psutil",
        # NOTE(review): train.py also imports graflag_runner, which is not
        # declared here; confirm how it is provided in the runtime image.
    ],
    # torch>=2.0, numpy>=1.24 and scikit-learn>=1.3 all require Python 3.8+,
    # so advertising 3.7 support only produced unresolvable installs.
    python_requires=">=3.8",
)
|
|
@@ -0,0 +1,233 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Generic PyGOD Bond Training Script
|
|
4
|
+
|
|
5
|
+
This script trains any PyGOD detector based on METHOD_NAME environment variable.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
import sys
|
|
10
|
+
import time
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
|
|
13
|
+
import psutil
|
|
14
|
+
import torch
|
|
15
|
+
|
|
16
|
+
# Import graflag_runner utilities
|
|
17
|
+
from graflag_runner import ResultWriter
|
|
18
|
+
from graflag_runner import info, warning, error
|
|
19
|
+
|
|
20
|
+
# Import PyGOD
|
|
21
|
+
from pygod.utils import load_data
|
|
22
|
+
|
|
23
|
+
# Import bond utilities
|
|
24
|
+
from graflag_bond.detectors import BondDetector
|
|
25
|
+
from graflag_bond.utils import get_all_parameters
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def load_graph_data(data_dir):
    """Load a PyGOD dataset whose name is the last component of ``data_dir``.

    ``data_dir`` is also passed to PyGOD's ``load_data`` as its cache
    directory. If the optional, comma-separated ``SUPPORTED_DATA``
    environment variable lists datasets and this one is not among them,
    a warning is logged (loading still proceeds).

    Args:
        data_dir: ``pathlib.Path`` whose ``name`` is the dataset name.

    Returns:
        The graph object returned by ``load_data``.
    """
    # Drop empty entries: "".split(", ") is [""], which is truthy and used to
    # trigger the warning for every dataset when SUPPORTED_DATA was unset.
    raw_supported = os.environ.get("SUPPORTED_DATA", "")
    supported_data = [item.strip() for item in raw_supported.split(",") if item.strip()]
    dataset_name = data_dir.name

    if supported_data and dataset_name not in supported_data:
        warning(f"Dataset '{dataset_name}' may not be officially tested. Supported: {supported_data}")

    info(f"Loading dataset: {dataset_name} from {data_dir}")

    # Load data using PyGOD's load_data
    data = load_data(dataset_name, cache_dir=data_dir)
    info(f"Graph: {data.num_nodes} nodes, {data.num_edges} edges, {data.num_features} features")

    return data
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def train_detector(method_name, data, exp_dir, writer):
    """Build and fit the PyGOD detector selected by ``method_name``.

    Args:
        method_name: Method identifier (e.g. 'bond_dominant'), resolved via
            ``BondDetector.from_method_name``.
        data: Graph object passed to the detector's ``fit``.
        exp_dir: Experiment output directory. Not used here; kept so the call
            signature stays stable.
        writer: Result writer; receives a "training" spot record with the
            epoch count and training time in seconds.

    Returns:
        The fitted detector instance.

    Raises:
        ValueError: If the method name does not resolve to a known detector.
    """
    # Get detector name and class dynamically
    detector_name = BondDetector.from_method_name(method_name)
    detector_class = BondDetector.get_detector_class(detector_name)

    # Constructor kwargs come from "_"-prefixed environment variables, with
    # types guided by the detector's __init__ signature.
    params = get_all_parameters(detector_class)

    info("=" * 60)
    info(f"Training {detector_name.upper()} Model")
    info("=" * 60)

    # Log key parameters
    info(f"Detector: {detector_class.__name__}")
    if "hid_dim" in params:
        info(f"Architecture: hid_dim={params['hid_dim']}, num_layers={params.get('num_layers', 'N/A')}")
    if "epoch" in params:
        info(f"Training: epochs={params['epoch']}, lr={params.get('lr', 'N/A')}")
    if "contamination" in params:
        info(f"Contamination: {params['contamination']}")

    # Initialize model
    info(f"Initializing {detector_name.upper()} detector...")
    model = detector_class(**params)

    # Train model. perf_counter is monotonic, so the measured duration is not
    # skewed by system clock adjustments the way time.time() can be.
    info("Starting training...")
    start_time = time.perf_counter()
    model.fit(data)
    training_time = time.perf_counter() - start_time

    # Log training metrics
    writer.spot("training",
                epochs=params.get('epoch', 'N/A'),
                training_time_sec=training_time)

    info(f"Training completed in {training_time:.2f}s")

    return model
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def save_results(model, data, exp_dir, writer, method_name, dataset_name,
                 exec_time_ms, peak_memory_mb, peak_gpu_mb=None):
    """Write node anomaly scores, run metadata and resource metrics.

    Args:
        model: Fitted detector exposing ``decision_score_``.
        data: Graph whose ``y`` holds per-node labels; 0 is treated as normal
            and any non-zero value as an anomaly.
        exp_dir: Output directory (only used in the final log message).
        writer: Result writer that receives scores, metadata and resource
            metrics; ``finalize()`` is called at the end.
        method_name: Method identifier recorded in the metadata.
        dataset_name: Dataset name recorded in the metadata.
        exec_time_ms: Total execution time in milliseconds.
        peak_memory_mb: Peak memory in MB as measured by the caller.
        peak_gpu_mb: Peak GPU memory in MB, or None if not measured.
    """
    info("=" * 60)
    info("Generating Results")
    info("=" * 60)

    # Get anomaly scores
    scores = model.decision_score_

    # Binarize ground truth: 0 = normal, any non-zero = anomaly.
    labels = data.y.cpu() if hasattr(data.y, 'cpu') else data.y
    if getattr(labels, 'ndim', None) == 1:
        # One bulk conversion to Python scalars instead of per-element 0-d
        # tensor comparisons, which are very slow on large graphs.
        labels = labels.tolist()
    ground_truth = [1 if label != 0 else 0 for label in labels]

    # Save results using ResultWriter
    writer.save_scores(
        result_type="NODE_ANOMALY_SCORES",
        scores=scores.tolist(),
        ground_truth=ground_truth,
        node_ids=list(range(len(scores)))
    )

    # Get detector info
    detector_name = BondDetector.from_method_name(method_name)
    detector_class = BondDetector.get_detector_class(detector_name)
    params = get_all_parameters(detector_class)

    # Convert params to JSON-safe values: callables/classes (activation
    # functions, backbones) are recorded by dotted name. Fall back to str()
    # when the object lacks a module or qualified name instead of crashing.
    safe_params = {}
    for key, value in params.items():
        if callable(value) or isinstance(value, type):
            module = getattr(value, '__module__', None)
            qualname = getattr(value, '__qualname__', None) or getattr(value, '__name__', None)
            safe_params[key] = f"{module}.{qualname}" if module and qualname else str(value)
        else:
            safe_params[key] = value

    # Add metadata. Path(...).name ignores a trailing slash, unlike
    # os.path.basename, which would return "".
    writer.add_metadata(
        exp_name=Path(os.environ.get("EXP", "experiment")).name,
        method_name=method_name,
        dataset=dataset_name,
        method_parameters=safe_params,
        threshold=None,
        summary={
            "description": f"PyGOD {detector_name.upper()} detector",
            "task": "node_anomaly_detection",
            "dataset_info": {
                "name": dataset_name,
                "num_nodes": data.num_nodes,
                "num_edges": data.num_edges,
                "num_features": data.num_features,
                "num_anomalies": sum(ground_truth),
            },
        },
    )

    # Add resource metrics
    writer.add_resource_metrics(
        exec_time_ms=exec_time_ms,
        peak_memory_mb=peak_memory_mb,
        peak_gpu_mb=peak_gpu_mb,
    )

    # Finalize results
    writer.finalize()

    info(f"Results saved to {exp_dir}")
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def main():
    """Run one PyGOD experiment configured through environment variables.

    Reads METHOD_NAME, DATA (dataset directory; its name is the dataset name)
    and EXP (output directory). Exits with status 1 if any of them is unset
    or empty, or if loading, training or saving raises an exception.
    """
    # Get environment variables
    method_name = os.environ.get("METHOD_NAME")
    if not method_name:
        error("METHOD_NAME environment variable not set!")
        sys.exit(1)

    # DATA and EXP are required as well: Path(None) would otherwise raise an
    # unhandled TypeError here, outside the error handling below.
    data_env = os.environ.get("DATA")
    exp_env = os.environ.get("EXP")
    for var_name, var_value in (("DATA", data_env), ("EXP", exp_env)):
        if not var_value:
            error(f"{var_name} environment variable not set!")
            sys.exit(1)

    data_dir = Path(data_env)
    exp_dir = Path(exp_env)

    info("=" * 60)
    info(f"PyGOD Bond: {method_name.upper()}")
    info("=" * 60)
    info(f"Dataset: {data_dir}")
    info(f"Output: {exp_dir}")
    info("")

    # Create experiment directory
    exp_dir.mkdir(parents=True, exist_ok=True)

    # Resource tracking. perf_counter is monotonic, so the elapsed time is
    # immune to system clock changes. Memory is *sampled* (RSS after loading
    # and after training), so the reported "peak" is the larger sample, not
    # a true high-water mark.
    start_time = time.perf_counter()
    process = psutil.Process()
    peak_memory_mb = 0.0

    # Initialize ResultWriter
    writer = ResultWriter()

    try:
        # Load data
        data = load_graph_data(data_dir)
        peak_memory_mb = max(peak_memory_mb, process.memory_info().rss / (1024 * 1024))

        # Train model
        model = train_detector(method_name, data, exp_dir, writer)
        peak_memory_mb = max(peak_memory_mb, process.memory_info().rss / (1024 * 1024))

        # Calculate resource metrics
        exec_time_ms = (time.perf_counter() - start_time) * 1000

        # Track GPU memory if available
        peak_gpu_mb = None
        if torch.cuda.is_available():
            gpu_bytes = torch.cuda.max_memory_allocated()
            if gpu_bytes > 0:
                peak_gpu_mb = gpu_bytes / (1024 * 1024)

        # Save results with metadata and resource metrics
        save_results(model, data, exp_dir, writer, method_name, data_dir.name,
                     exec_time_ms, peak_memory_mb, peak_gpu_mb)

        info("")
        info("[INFO] Resource Usage:")
        info(f" [INFO] Execution time: {exec_time_ms/1000:.2f}s")
        info(f" [INFO] Peak memory: {peak_memory_mb:.2f}MB")
        if peak_gpu_mb is not None:
            info(f" [INFO] Peak GPU memory: {peak_gpu_mb:.2f}MB")

        info("")
        info("=" * 60)
        info(f"{method_name.upper()} execution completed successfully!")
        info("=" * 60)

    except Exception as e:
        # Top-level boundary: report the failure and exit non-zero.
        error(f"Error during execution: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utility functions for graflag_bond.
|
|
3
|
+
|
|
4
|
+
Dynamically handles parameter extraction from environment variables.
|
|
5
|
+
Converts values to appropriate Python types based on parameter names and values.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
from typing import Dict, Any
|
|
10
|
+
import torch.nn.functional as F
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def str_to_bool(value: str) -> bool:
    """Interpret ``value`` as a boolean flag.

    Case-insensitive "true", "1" and "yes" are True; anything else,
    including values with surrounding whitespace, is False.
    """
    truthy_tokens = {'true', '1', 'yes'}
    normalized = value.lower()
    return normalized in truthy_tokens
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def get_activation_function(activation_value: str):
    """
    Convert an activation function path/name to a torch.nn.functional callable.

    Args:
        activation_value: Activation function path (e.g., 'torch.nn.functional.relu')
            or bare name (e.g., 'relu').

    Returns:
        The matching callable from torch.nn.functional. Falls back to
        ``F.relu`` (emitting a RuntimeWarning) when no callable of that name
        exists, so misspelled names are visible instead of silently ignored.
    """
    # Extract function name from full path
    if 'torch.nn.functional.' in activation_value:
        func_name = activation_value.split('.')[-1]
    else:
        func_name = activation_value

    # hasattr alone also matched non-callable attributes of the module (e.g.
    # the ``torch`` module it imports), so require a callable.
    func = getattr(F, func_name, None)
    if callable(func):
        return func

    import warnings
    warnings.warn(
        f"Unknown activation function {activation_value!r}; "
        "falling back to torch.nn.functional.relu",
        RuntimeWarning,
        stacklevel=2,
    )
    return F.relu
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def get_backbone_class(backbone_value: str):
    """
    Convert a backbone path to a torch_geometric.nn class.

    Args:
        backbone_value: Backbone class path (e.g., 'torch_geometric.nn.GCN'),
            a bare class name, or 'none' (any case).

    Returns:
        The attribute of torch_geometric.nn named by the last dotted
        component, or None. None is returned silently for an explicit
        'none'. For an unknown name or when torch_geometric cannot be
        imported, None is returned with a RuntimeWarning so the
        misconfiguration is visible.
    """
    if backbone_value.lower() == 'none':
        return None

    # Extract class name from full path
    if 'torch_geometric.nn.' in backbone_value:
        class_name = backbone_value.split('.')[-1]
    else:
        class_name = backbone_value

    import warnings

    try:
        import torch_geometric.nn as pyg_nn
    except ImportError:
        warnings.warn(
            f"Could not import torch_geometric.nn; backbone {backbone_value!r} resolved to None",
            RuntimeWarning,
            stacklevel=2,
        )
        return None

    backbone = getattr(pyg_nn, class_name, None)
    if backbone is None:
        warnings.warn(
            f"Unknown backbone {backbone_value!r} (not found in torch_geometric.nn); resolved to None",
            RuntimeWarning,
            stacklevel=2,
        )
    return backbone
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def convert_env_value(env_name: str, env_value: str, expected_type: type = None) -> Any:
    """
    Convert an environment variable string to an appropriate Python value.

    Resolution order:
      1. Values containing 'torch.nn.functional' -> activation function.
      2. Values containing 'torch_geometric.nn' -> backbone class (or None).
      3. 'none' (any case) -> None.
      4. 'true' / 'false' (any case) -> bool.
      5. expected_type of float/int/bool/str, if given and conversion succeeds.
      6. int, then float (including scientific notation such as '1e-3').
      7. The original string.

    Args:
        env_name: Name of environment variable (uppercase); currently unused,
            kept for interface compatibility.
        env_value: String value from environment
        expected_type: Expected type from function signature (if available)

    Returns:
        Converted value with appropriate type
    """
    # Handle activation functions (callable)
    if 'torch.nn.functional' in env_value:
        return get_activation_function(env_value)

    # Handle backbone classes (torch.nn.Module)
    if 'torch_geometric.nn' in env_value:
        return get_backbone_class(env_value)

    lowered = env_value.lower()

    # Handle None
    if lowered == 'none':
        return None

    # Handle boolean values
    if lowered in ('true', 'false'):
        return str_to_bool(env_value)

    # If we have expected type from signature, use it
    if expected_type is not None:
        try:
            if expected_type is float:
                return float(env_value)
            if expected_type is int:
                return int(env_value)
            if expected_type is bool:
                return str_to_bool(env_value)
            if expected_type is str:
                return env_value
        except (ValueError, TypeError):
            pass

    # Fallback: infer the type from the value. The old "int unless it
    # contains '.'" rule returned values like "1e-3" as strings, because
    # int() failed and float() was never attempted.
    try:
        return int(env_value)
    except ValueError:
        pass

    # Only try float for values containing a digit, so words that float()
    # happens to accept ("nan", "inf", "infinity") remain strings as before.
    if any(ch.isdigit() for ch in env_value):
        try:
            return float(env_value)
        except ValueError:
            pass

    # Return as string if conversion fails
    return env_value
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def get_all_parameters(detector_class=None) -> Dict[str, Any]:
    """
    Collect detector keyword arguments from '_'-prefixed environment variables.

    ``_HID_DIM=64`` becomes ``{"hid_dim": 64}``. The leading underscore is
    dropped and the name is lower-cased. The value is converted with
    ``convert_env_value``, using the type of the matching ``__init__``
    parameter of ``detector_class`` as a hint: its annotation, or else the
    type of a non-None default.

    The bare ``_`` variable and ``__``-prefixed variables are ignored. They
    are set by the shell/platform rather than by users, and ``_`` would
    otherwise produce an empty keyword name that breaks ``detector_class(**params)``.

    Args:
        detector_class: Optional detector class to inspect for parameter types

    Returns:
        Dictionary mapping parameter name to converted value
    """
    import inspect

    # Parameter types from the detector's __init__ signature, if available.
    param_types = {}
    if detector_class is not None:
        try:
            parameters = inspect.signature(detector_class.__init__).parameters
        except (ValueError, TypeError):
            parameters = {}
        variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        for param_name, param in parameters.items():
            if param_name in ('self', 'args', 'kwargs') or param.kind in variadic:
                continue
            # Identity checks: defaults may be arbitrary objects whose
            # __eq__/__ne__ does not return a plain bool.
            if param.annotation is not inspect.Parameter.empty:
                param_types[param_name] = param.annotation
            elif param.default is not inspect.Parameter.empty and param.default is not None:
                param_types[param_name] = type(param.default)

    params = {}
    for env_name, env_value in os.environ.items():
        # Only process variables that start with underscore
        if not env_name.startswith('_'):
            continue
        # bash exports "_" (path of the command being run) into the child's
        # environment, and some platforms define "__"-prefixed variables
        # (e.g. macOS's __CF_USER_TEXT_ENCODING); neither is a detector
        # parameter.
        if env_name == '_' or env_name.startswith('__'):
            continue

        param_name = env_name[1:].lower()
        expected_type = param_types.get(param_name)
        params[param_name] = convert_env_value(env_name, env_value, expected_type)

    return params
|