gridfm-graphkit 0.0.4__tar.gz → 0.0.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/PKG-INFO +24 -20
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/README.md +11 -6
- gridfm_graphkit-0.0.6/gridfm_graphkit/datasets/postprocessing.py +83 -0
- gridfm_graphkit-0.0.6/gridfm_graphkit/utils/utils.py +42 -0
- gridfm_graphkit-0.0.6/gridfm_graphkit/utils/visualization.py +513 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit.egg-info/PKG-INFO +24 -20
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit.egg-info/SOURCES.txt +2 -0
- gridfm_graphkit-0.0.6/gridfm_graphkit.egg-info/requires.txt +22 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/pyproject.toml +13 -14
- gridfm_graphkit-0.0.4/gridfm_graphkit/utils/visualization.py +0 -99
- gridfm_graphkit-0.0.4/gridfm_graphkit.egg-info/requires.txt +0 -23
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/LICENSE +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit/__init__.py +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit/__main__.py +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit/cli.py +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit/datasets/__init__.py +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit/datasets/globals.py +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit/datasets/normalizers.py +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit/datasets/powergrid_datamodule.py +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit/datasets/powergrid_dataset.py +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit/datasets/transforms.py +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit/datasets/utils.py +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit/io/__init__.py +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit/io/param_handler.py +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit/io/registries.py +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit/models/__init__.py +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit/models/gnn_transformer.py +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit/models/gps_transformer.py +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit/tasks/__init__.py +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit/tasks/feature_reconstruction_task.py +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit/training/__init__.py +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit/training/callbacks.py +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit/training/loss.py +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit/utils/__init__.py +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit.egg-info/dependency_links.txt +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit.egg-info/entry_points.txt +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit.egg-info/top_level.txt +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/setup.cfg +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/tests/test_data_module.py +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/tests/test_full_pipeline.py +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/tests/test_losses.py +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/tests/test_model_outputs.py +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/tests/test_normalization.py +0 -0
- {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/tests/test_yaml_configs.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: gridfm-graphkit
|
|
3
|
-
Version: 0.0.4
|
|
3
|
+
Version: 0.0.6
|
|
4
4
|
Summary: Grid Foundation Model
|
|
5
5
|
Author-email: Matteo Mazzonelli <matteo.mazzonelli1@ibm.com>, Alban Puech <apuech@seas.harvard.edu>, Tamara Govindasamy <tamara.govindasamy@ibm.com>, Mangaliso Mngomezulu <mngomezulum@ibm.com>, Etienne Vos <etienne.vos@ibm.com>, Celia Cintas <celia.cintas@ibm.com>, Jonas Weiss <jwe@zurich.ibm.com>
|
|
6
6
|
Maintainer-email: Matteo Mazzonelli <matteo.mazzonelli1@ibm.com>
|
|
@@ -14,30 +14,39 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
|
14
14
|
Requires-Python: <3.13,>=3.10
|
|
15
15
|
Description-Content-Type: text/markdown
|
|
16
16
|
License-File: LICENSE
|
|
17
|
-
Requires-Dist:
|
|
18
|
-
Requires-Dist:
|
|
19
|
-
Requires-Dist:
|
|
20
|
-
Requires-Dist:
|
|
21
|
-
Requires-Dist:
|
|
22
|
-
Requires-Dist:
|
|
23
|
-
Requires-Dist:
|
|
24
|
-
Requires-Dist:
|
|
25
|
-
Requires-Dist:
|
|
26
|
-
Requires-Dist: torchaudio>=2.7.1
|
|
27
|
-
Requires-Dist: torchvision>=0.22.1
|
|
17
|
+
Requires-Dist: torch>2.0
|
|
18
|
+
Requires-Dist: torch-geometric
|
|
19
|
+
Requires-Dist: mlflow
|
|
20
|
+
Requires-Dist: nbformat
|
|
21
|
+
Requires-Dist: networkx
|
|
22
|
+
Requires-Dist: numpy
|
|
23
|
+
Requires-Dist: pandas
|
|
24
|
+
Requires-Dist: plotly
|
|
25
|
+
Requires-Dist: pyyaml
|
|
28
26
|
Requires-Dist: lightning
|
|
27
|
+
Requires-Dist: seaborn
|
|
29
28
|
Provides-Extra: dev
|
|
30
29
|
Requires-Dist: mkdocs-material; extra == "dev"
|
|
31
30
|
Requires-Dist: mkdocstrings[python]; extra == "dev"
|
|
32
|
-
Requires-Dist: pre-commit
|
|
33
|
-
Requires-Dist: bandit
|
|
31
|
+
Requires-Dist: pre-commit; extra == "dev"
|
|
32
|
+
Requires-Dist: bandit; extra == "dev"
|
|
34
33
|
Requires-Dist: build; extra == "dev"
|
|
35
34
|
Provides-Extra: test
|
|
36
35
|
Requires-Dist: pytest; extra == "test"
|
|
37
36
|
Requires-Dist: pytest-cov; extra == "test"
|
|
38
37
|
Dynamic: license-file
|
|
39
38
|
|
|
40
|
-
|
|
39
|
+
<p align="center">
|
|
40
|
+
<img src="https://raw.githubusercontent.com/gridfm/gridfm-graphkit/refs/heads/main/docs/figs/KIT.png" alt="GridFM logo" style="width: 40%; height: auto;"/>
|
|
41
|
+
<br/>
|
|
42
|
+
</p>
|
|
43
|
+
|
|
44
|
+
<p align="center" style="font-size: 25px;">
|
|
45
|
+
<b>gridfm-graphkit</b>
|
|
46
|
+
</p>
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
[DOI: 10.5281/zenodo.17016737](https://doi.org/10.5281/zenodo.17016737)
|
|
41
50
|
[Documentation](https://gridfm.github.io/gridfm-graphkit/)
|
|
42
51
|

|
|
43
52
|

|
|
@@ -47,11 +56,6 @@ This library is brought to you by the GridFM team to train, finetune and interac
|
|
|
47
56
|
|
|
48
57
|
---
|
|
49
58
|
|
|
50
|
-
<p align="center">
|
|
51
|
-
<img src="https://raw.githubusercontent.com/gridfm/gridfm-graphkit/refs/heads/main/docs/figs/pre_training.png" alt="GridFM logo"/>
|
|
52
|
-
<br/>
|
|
53
|
-
</p>
|
|
54
|
-
|
|
55
59
|
# Installation
|
|
56
60
|
|
|
57
61
|
You can install `gridfm-graphkit` directly from PyPI:
|
|
@@ -1,4 +1,14 @@
|
|
|
1
|
-
|
|
1
|
+
<p align="center">
|
|
2
|
+
<img src="https://raw.githubusercontent.com/gridfm/gridfm-graphkit/refs/heads/main/docs/figs/KIT.png" alt="GridFM logo" style="width: 40%; height: auto;"/>
|
|
3
|
+
<br/>
|
|
4
|
+
</p>
|
|
5
|
+
|
|
6
|
+
<p align="center" style="font-size: 25px;">
|
|
7
|
+
<b>gridfm-graphkit</b>
|
|
8
|
+
</p>
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
[DOI: 10.5281/zenodo.17016737](https://doi.org/10.5281/zenodo.17016737)
|
|
2
12
|
[Documentation](https://gridfm.github.io/gridfm-graphkit/)
|
|
3
13
|

|
|
4
14
|

|
|
@@ -8,11 +18,6 @@ This library is brought to you by the GridFM team to train, finetune and interac
|
|
|
8
18
|
|
|
9
19
|
---
|
|
10
20
|
|
|
11
|
-
<p align="center">
|
|
12
|
-
<img src="https://raw.githubusercontent.com/gridfm/gridfm-graphkit/refs/heads/main/docs/figs/pre_training.png" alt="GridFM logo"/>
|
|
13
|
-
<br/>
|
|
14
|
-
</p>
|
|
15
|
-
|
|
16
21
|
# Installation
|
|
17
22
|
|
|
18
23
|
You can install `gridfm-graphkit` directly from PyPI:
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from scipy.sparse import csr_matrix
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def compute_branch_currents_kA(Yf, Yt, V, Vf_base_kV, Vt_base_kV, sn_mva):
    """
    Compute branch current magnitudes in kA at both ends of every branch.

    The per-unit branch currents are obtained from the branch admittance
    matrices (``I_f = Yf @ V`` and ``I_t = Yt @ V``). Their magnitudes are
    converted to kA with the three-phase base current
    ``I_base = S_base / (sqrt(3) * V_base)``.

    Parameters:
    - Yf: from-end branch admittance matrix (one row per branch, one column per
      bus), e.g. as built by ``create_admittance_matrix``.
    - Yt: to-end branch admittance matrix, same layout as ``Yf``.
    - V: complex bus voltages in per-unit, one entry per bus.
    - Vf_base_kV: base voltage (kV) of each branch's from-bus.
    - Vt_base_kV: base voltage (kV) of each branch's to-bus.
    - sn_mva: system base power in MVA.

    Returns:
    - (If_kA, It_kA): from-end and to-end current magnitudes in kA.
    """
    sqrt3 = np.sqrt(3)

    def _magnitude_kA(current_pu, base_kv):
        # |I_pu| * S_base / (sqrt(3) * V_base): per-unit -> kA.
        return np.abs(current_pu) * sn_mva / (sqrt3 * base_kv)

    from_current_pu = Yf @ V
    to_current_pu = Yt @ V
    return (
        _magnitude_kA(from_current_pu, Vf_base_kV),
        _magnitude_kA(to_current_pu, Vt_base_kV),
    )
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def compute_loading(If_kA, It_kA, Vf_base_kV, Vt_base_kV, rate_a):
    """
    Compute per-branch loading from current magnitudes and branch ratings.

    The rating ``rate_a`` is converted into a current limit in kA at each
    branch end, ``I_lim = rate_a / (sqrt(3) * V_base)``. The loading of an end
    is its current divided by that limit, so 1.0 means exactly at the rating.

    Parameters:
    - If_kA: np.ndarray of from-side current magnitudes in kA, one per branch
    - It_kA: np.ndarray of to-side current magnitudes in kA, one per branch
    - Vf_base_kV: np.ndarray of from-bus base voltages in kV, one per branch
    - Vt_base_kV: np.ndarray of to-bus base voltages in kV, one per branch
    - rate_a: np.ndarray of branch ratings (RATE_A) in MVA, one per branch

    Returns:
    - loading: np.ndarray with one entry per branch, the element-wise maximum
      of the from-side and to-side loading.

    Note:
    - A zero entry in ``rate_a`` gives a zero current limit. That branch's
      loading then becomes non-finite (``inf``, or ``nan`` if either end
      carries zero current, because ``np.maximum`` propagates ``nan``), and
      numpy emits a RuntimeWarning.
    """
    # Convert the MVA rating to a current limit in kA at each branch end
    # (I = S / (sqrt(3) * V)), using that end's base voltage.
    limitf = rate_a / (Vf_base_kV * np.sqrt(3))
    limitt = rate_a / (Vt_base_kV * np.sqrt(3))

    loadingf = If_kA / limitf
    loadingt = It_kA / limitt

    # A branch is as loaded as its more heavily loaded end.
    return np.maximum(loadingf, loadingt)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def create_admittance_matrix(bus_params, edge_params, sn_mva=100):
    """
    Build the sparse from-end and to-end branch admittance matrices.

    Row ``b`` of each matrix holds branch ``b``'s coefficients in the columns
    of its two terminal buses:

        Yf[b, from_bus] = Yff_b,  Yf[b, to_bus] = Yft_b
        Yt[b, from_bus] = Ytf_b,  Yt[b, to_bus] = Ytt_b

    So ``Yf @ V`` and ``Yt @ V`` give the per-unit branch currents at the
    from and to ends for a complex bus-voltage vector ``V``.

    Parameters:
    - bus_params: pandas DataFrame with one row per bus and a ``baseKV`` column.
    - edge_params: pandas DataFrame with one row per branch and the columns
      ``from_bus``, ``to_bus``, ``Yff_r``, ``Yff_i``, ``Yft_r``, ``Yft_i``,
      ``Ytf_r``, ``Ytf_i``, ``Ytt_r``, ``Ytt_i``.
    - sn_mva: system base power in MVA. Currently unused by this function.

    Returns:
    - Yf: complex scipy.sparse.csr_matrix of shape (n_branches, n_buses).
    - Yt: complex scipy.sparse.csr_matrix of shape (n_branches, n_buses).
    - Vf_base_kV: np.ndarray, base voltage (kV) of each branch's from-bus.
    - Vt_base_kV: np.ndarray, base voltage (kV) of each branch's to-bus.

    NOTE(review): ``from_bus``/``to_bus`` are used as 0-based positional
    indices into ``bus_params``. Presumably buses are numbered
    0..n_buses-1; confirm against the data producer.
    """
    base_kv = bus_params["baseKV"].values
    n_branches = edge_params.shape[0]
    n_buses = bus_params.shape[0]

    from_bus = edge_params["from_bus"].values.astype(np.int32)
    to_bus = edge_params["to_bus"].values.astype(np.int32)

    def _complex(real_col, imag_col):
        # Combine a real/imaginary column pair into complex coefficients.
        return edge_params[real_col].values + 1j * edge_params[imag_col].values

    # Each branch contributes two entries (one per terminal bus) to its row.
    rows = np.concatenate([np.arange(n_branches), np.arange(n_branches)])
    cols = np.concatenate([from_bus, to_bus])

    def _branch_matrix(coef_at_from, coef_at_to):
        data = np.concatenate([coef_at_from, coef_at_to])
        return csr_matrix((data, (rows, cols)), shape=(n_branches, n_buses))

    Yf = _branch_matrix(_complex("Yff_r", "Yff_i"), _complex("Yft_r", "Yft_i"))
    Yt = _branch_matrix(_complex("Ytf_r", "Ytf_i"), _complex("Ytt_r", "Ytt_i"))

    return Yf, Yt, base_kv[from_bus], base_kv[to_bus]
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
def compute_cm_metrics(y_test, y_pred, model_name, label_plot):
    """
    Compute the confusion matrix (TP, FP, TN, FN) of predicted overloads.

    Prints the counts together with accuracy and the TPR/FPR/TNR/FNR rates.
    Writes the same summary to ``metrics_overloading_{model_name}.txt`` in the
    current working directory.

    Parameters:
    - y_test: ground-truth overload flags (boolean array-like).
    - y_pred: predicted overload flags (boolean array-like, same shape as y_test).
    - model_name: identifier inserted into the output file name.
    - label_plot: label used in the printed and saved metrics header.

    Returns:
    - (TP, FP, TN, FN) counts.

    Notes:
    - Inputs must be boolean. ``~`` is a logical NOT only for boolean arrays;
      on integer 0/1 arrays it is a bitwise NOT, and the counts would be wrong.
    - A rate whose denominator is zero (e.g. TPR when there are no true
      overloads) is reported as ``nan`` instead of triggering a division
      warning or error.
    """

    def _ratio(num, den):
        # Guard 0/0: numpy scalars would warn, Python ints would raise.
        return num / den if den else float("nan")

    TP = (y_test & y_pred).sum()
    FP = ((~y_test) & y_pred).sum()
    TN = ((~y_test) & (~y_pred)).sum()
    FN = (y_test & (~y_pred)).sum()

    # accuracy
    accuracy = _ratio(TP + TN, TP + FP + TN + FN)
    print(f"Accuracy: {accuracy:.3f}")

    TPR = _ratio(TP, TP + FN)
    FPR = _ratio(FP, FP + TN)
    TNR = _ratio(TN, TN + FP)
    FNR = _ratio(FN, FN + TP)
    # TODO change text to fit both overloadings and voltage violations
    print("Confusion Matrix:")
    print(f"TP: {TP}, FP: {FP}, TN: {TN}, FN: {FN}")
    # Use the caller-supplied label (as the saved file does) instead of a
    # hard-coded model name.
    print(
        f"{label_plot}\nTPR: {TPR:.3f} (percentage of overloadings correctly predicted)\nFPR: {FPR:.3f} (percentage of non-overloadings predicted as overloadings)\nTNR: {TNR:.2f}\nFNR: {FNR:.2f}",
    )
    with open(f"metrics_overloading_{model_name}.txt", "w", encoding="utf-8") as f:
        f.write(f"Accuracy: {accuracy:.3f}\n")
        f.write("Confusion Matrix:\n")
        f.write(f"TP: {TP}, FP: {FP}, TN: {TN}, FN: {FN}\n")
        f.write(f"{label_plot} Metrics:\n")
        f.write(f"TPR: {TPR:.5f} (percentage of overloadings correctly predicted)\n")
        f.write(
            f"FPR: {FPR:.5f} (percentage of non-overloadings predicted as overloadings)\n",
        )
        f.write(f"TNR: {TNR:.5f}\n")
        f.write(f"FNR: {FNR:.5f}\n")
    return TP, FP, TN, FN
|
|
@@ -0,0 +1,513 @@
|
|
|
1
|
+
from gridfm_graphkit.training.loss import PBELoss
|
|
2
|
+
from gridfm_graphkit.datasets.globals import PQ, PV, REF
|
|
3
|
+
|
|
4
|
+
import networkx as nx
|
|
5
|
+
import matplotlib.pyplot as plt
|
|
6
|
+
from matplotlib.colors import LogNorm
|
|
7
|
+
from scipy.stats import pearsonr
|
|
8
|
+
import seaborn as sns
|
|
9
|
+
import numpy as np
|
|
10
|
+
import copy
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def visualize_error(data_point, output, node_normalizer):
    """
    Plot the nodal active-power residuals of a prediction on the grid graph.

    The residuals come from ``PBELoss(visualization=True)`` (key
    ``"Nodal Active Power Loss in p.u."``). They are multiplied by
    ``node_normalizer.baseMVA`` and shown on a colour bar labelled in MW.
    Buses are drawn with one marker shape per bus type (REF square, PV
    hexagon, PQ circle) and coloured by their residual.

    Parameters:
    - data_point: graph sample exposing ``x`` (bus-type indicator columns at
      ``PQ``/``PV``/``REF``), ``y``, ``edge_index``, ``edge_attr`` and ``mask``.
    - output: model prediction, passed to ``PBELoss`` together with the
      fields of ``data_point``.
    - node_normalizer: normalizer providing ``baseMVA``.
    """
    loss = PBELoss(visualization=True)

    loss_dict = loss(
        output,
        data_point.y,
        data_point.edge_index,
        data_point.edge_attr,
        data_point.mask,
    )
    active_loss = loss_dict["Nodal Active Power Loss in p.u."]
    # Scale per-unit residuals by the base power for display in MW.
    active_loss = active_loss.cpu() * node_normalizer.baseMVA

    num_nodes = data_point.x.shape[0]

    # Create a graph; self-loops are skipped since they have no drawable line.
    G = nx.Graph()
    edges = [
        (u, v)
        for u, v in zip(
            data_point.edge_index[0].tolist(),
            data_point.edge_index[1].tolist(),
        )
        if u != v
    ]
    G.add_edges_from(edges)
    # Add every bus explicitly. A bus whose only edges are self-loops (or
    # none) would otherwise be missing from G and from the layout, and drawing
    # it below would raise KeyError. Adding after the edges keeps the node
    # order, and hence the seeded layout, unchanged for connected buses.
    G.add_nodes_from(range(num_nodes))

    # Assign labels based on node type
    node_shapes = {"REF": "s", "PV": "H", "PQ": "o"}
    mask_PQ = data_point.x[:, PQ] == 1
    mask_PV = data_point.x[:, PV] == 1
    mask_REF = data_point.x[:, REF] == 1
    node_labels = {}
    for i in range(num_nodes):
        if mask_REF[i]:
            node_labels[i] = "REF"
        elif mask_PV[i]:
            node_labels[i] = "PV"
        elif mask_PQ[i]:
            node_labels[i] = "PQ"

    # Set node positions
    pos = nx.spring_layout(G, seed=42)

    # Define colormap
    cmap = plt.cm.viridis
    vmin = min(active_loss)
    vmax = max(active_loss)
    norm = plt.Normalize(vmin=vmin, vmax=vmax)

    # Create a figure and axis
    fig, ax = plt.subplots(figsize=(13, 7))

    # Draw nodes with heatmap coloring
    for node_type, shape in node_shapes.items():
        nodes = [i for i in node_labels if node_labels[i] == node_type]
        nx.draw_networkx_nodes(
            G,
            pos,
            nodelist=nodes,
            node_color=[active_loss[i] for i in nodes],
            cmap=cmap,
            node_size=800,
            ax=ax,
            vmin=vmin,
            vmax=vmax,
            node_shape=shape,
        )

    # Draw edges
    nx.draw_networkx_edges(G, pos, edge_color="gray", alpha=0.5, ax=ax)

    # Draw labels (node types)
    nx.draw_networkx_labels(
        G,
        pos,
        labels=node_labels,
        font_size=10,
        font_color="white",
        font_weight="bold",
        ax=ax,
    )

    # Add colorbar
    cbar = plt.colorbar(plt.cm.ScalarMappable(cmap=cmap, norm=norm), ax=ax)
    cbar.set_label("Active Power Residuals (MW)", fontsize=12)
    cbar.ax.tick_params(labelsize=12)

    for spine in ax.spines.values():
        spine.set_linewidth(2)  # Adjust thickness here (e.g., 2 or any value)

    # Title the graph axes explicitly (not the colour-bar axes), then show.
    ax.set_title("Nodal Active Power Residuals", fontsize=14, fontweight="bold")
    plt.show()
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def visualize_quantity_heatmap(
    data_point,
    output,
    quantity,
    quantity_name,
    unit,
    node_normalizer,
):
    """
    Show ground-truth, masked and reconstructed values of one bus quantity.

    Three side-by-side graph plots share one colour scale, which spans both
    the ground truth and the reconstruction:
    1. the denormalized ground truth;
    2. the same ground truth, with the buses masked for ``quantity`` overlaid
       in grey;
    3. the reconstruction: model output at masked buses and ground truth at
       unmasked buses.

    Parameters:
    - data_point: graph sample exposing ``x`` (bus-type indicator columns at
      ``PQ``/``PV``/``REF``), ``y``, ``edge_index`` and ``mask``. It is
      deep-copied, so the caller's object is not modified.
    - output: model prediction in normalized units, indexed by column like
      ``data_point.y``. It is also deep-copied.
    - quantity: column index of the quantity to plot (e.g. VM, PD, QD, PG,
      QG, VA).
    - quantity_name: display name used in panel titles and on the colour bar.
    - unit: unit string shown on the colour bar.
    - node_normalizer: object whose ``inverse_transform`` maps normalized
      values back to physical units.
    """
    data_point = copy.deepcopy(data_point)
    output = copy.deepcopy(output)
    mask_PQ = data_point.x[:, PQ] == 1
    mask_PV = data_point.x[:, PV] == 1
    mask_REF = data_point.x[:, REF] == 1

    output = node_normalizer.inverse_transform(output)
    denormalized_gt = node_normalizer.inverse_transform(data_point.y)

    gt_values = denormalized_gt[:, quantity]
    predicted_values = output[:, quantity]
    # Only masked buses are reconstructed; elsewhere show the known ground truth.
    predicted_values[~data_point.mask[:, quantity]] = denormalized_gt[
        ~data_point.mask[:, quantity],
        quantity,
    ]

    num_nodes = data_point.x.shape[0]

    node_shapes = {"REF": "s", "PV": "H", "PQ": "o"}

    # Create graph; self-loops are skipped since they have no drawable line.
    G = nx.Graph()
    edges = [
        (u, v)
        for u, v in zip(
            data_point.edge_index[0].tolist(),
            data_point.edge_index[1].tolist(),
        )
        if u != v
    ]
    G.add_edges_from(edges)
    # Add every bus so isolated / self-loop-only buses get a layout position;
    # otherwise drawing them raises KeyError. Added after the edges so the
    # seeded layout of connected buses is unchanged.
    G.add_nodes_from(range(num_nodes))

    node_labels = {}
    for i in range(num_nodes):
        if mask_REF[i]:
            node_labels[i] = "REF"
        elif mask_PV[i]:
            node_labels[i] = "PV"
        elif mask_PQ[i]:
            node_labels[i] = "PQ"

    pos = nx.spring_layout(G, seed=42)
    cmap = plt.cm.viridis
    # The colour range covers both ground truth and reconstruction. Otherwise
    # ground-truth values at masked buses could fall outside it and saturate.
    vmin = min(float(gt_values.min()), float(predicted_values.min()))
    vmax = max(float(gt_values.max()), float(predicted_values.max()))
    norm = plt.Normalize(vmin=vmin, vmax=vmax)

    masked_node_indices = np.where(data_point.mask[:, quantity].cpu())[0]

    # Create subplots for side-by-side layout (3 plots)
    fig, axes = plt.subplots(1, 3, figsize=(22, 8))

    # First plot (ground truth values)
    _draw_quantity_panel(
        axes[0], G, pos, node_labels, node_shapes, gt_values,
        cmap, vmin, vmax, f"Ground truth {quantity_name}",
    )
    # Second plot (ground truth with masked nodes in gray)
    _draw_quantity_panel(
        axes[1], G, pos, node_labels, node_shapes, gt_values,
        cmap, vmin, vmax, f"Masked {quantity_name}",
        overlay_nodes=masked_node_indices,
    )
    # Third plot (reconstructed values)
    _draw_quantity_panel(
        axes[2], G, pos, node_labels, node_shapes, predicted_values,
        cmap, vmin, vmax, f"Reconstructed {quantity_name}",
    )

    # Colorbar placement
    cbar_ax = fig.add_axes([0.93, 0.1, 0.02, 0.8])
    cbar = plt.colorbar(plt.cm.ScalarMappable(cmap=cmap, norm=norm), cax=cbar_ax)
    cbar.set_label(f"{quantity_name} ({unit})", fontsize=12)
    cbar.ax.tick_params(labelsize=12)

    plt.subplots_adjust(right=0.9)
    plt.show()


def _draw_quantity_panel(
    ax,
    G,
    pos,
    node_labels,
    node_shapes,
    values,
    cmap,
    vmin,
    vmax,
    title,
    overlay_nodes=None,
):
    """Draw one panel of ``visualize_quantity_heatmap`` on ``ax``.

    Buses are coloured by ``values`` with one shape per bus type (REF drawn
    smaller). ``overlay_nodes``, if given, are drawn on top in light grey.
    Edges, bus-type labels, the title and thick axes spines are then added.
    """
    for node_type, shape in node_shapes.items():
        nodes = [i for i in node_labels if node_labels[i] == node_type]
        node_size = 390 if node_type == "REF" else 600
        nx.draw_networkx_nodes(
            G,
            pos,
            nodelist=nodes,
            node_color=[values[i] for i in nodes],
            cmap=cmap,
            node_size=node_size,
            ax=ax,
            vmin=vmin,
            vmax=vmax,
            node_shape=shape,
        )

    if overlay_nodes is not None:
        nx.draw_networkx_nodes(
            G,
            pos,
            nodelist=overlay_nodes,
            node_color="#D3D3D3",
            node_size=750,
            ax=ax,
        )

    nx.draw_networkx_edges(G, pos, edge_color="gray", alpha=0.5, ax=ax, width=2)
    nx.draw_networkx_labels(
        G,
        pos,
        labels=node_labels,
        font_size=10,
        font_color="white",
        font_weight="bold",
        ax=ax,
    )
    ax.set_title(title, fontsize=14, fontweight="bold")

    for spine in ax.spines.values():
        spine.set_linewidth(2)  # Adjust thickness
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
def plot_mass_correlation_density(
    true_vals,
    gfm_vals,
    model_name,
    label_plot,
    x_max=2,
    y_max=3,
):
    """
    Plot a log-scaled 2-D histogram of predicted versus true branch loadings.

    Dashed lines mark a loading of 1 on both axes, and a thin diagonal marks
    perfect prediction. The Pearson correlation is written on the plot. The
    figure is saved to ``mass_correlation_density_{model_name}.png`` in the
    current working directory and then shown.

    Parameters:
    - true_vals: 1-D array-like of ground-truth loadings.
    - gfm_vals: 1-D array-like of predicted loadings, aligned with true_vals.
    - model_name: identifier inserted into the output file name.
    - label_plot: plot title.
    - x_max: upper limit of the true-loading axis (lower limit is 0).
    - y_max: upper limit of the predicted-loading axis (lower limit is 0).
    """
    # TODO check if these parameters need to be passed by func or default behavior
    vmin = 1  # LogNorm cannot represent zero counts, so the scale starts at 1
    x_min = 0
    y_min = 0
    bin_width = 0.01  # consistent bin width for both axes

    # Generate consistent bins
    x_bins = np.arange(x_min, x_max + bin_width, bin_width)
    y_bins = np.arange(y_min, y_max + bin_width, bin_width)

    # Top of the colour scale: mean + 3 std of the non-empty bin counts
    # (empty bins are turned into NaN so they are ignored).
    counts, _, _ = np.histogram2d(true_vals, gfm_vals, bins=[x_bins, y_bins])
    counts[counts == 0] = np.nan
    means = np.nanmean(counts)
    std = np.nanstd(counts)
    vmax = means + 3 * std

    # Pearson correlation
    corr_gfm, _ = pearsonr(true_vals, gfm_vals)

    fig, ax1 = plt.subplots(figsize=(8, 6), dpi=400)

    # --- GridFM Mass Correlation ---
    h1 = ax1.hist2d(
        true_vals,
        gfm_vals,
        bins=[x_bins, y_bins],
        norm=LogNorm(vmin=vmin, vmax=vmax),
        cmap="inferno",
    )
    # Loading = 1 on both axes.
    ax1.axvline(1, color="black", linestyle="--", linewidth=2.0)
    ax1.axhline(1, color="black", linestyle="--", linewidth=2.0)
    # y = x diagonal, long enough to cross the whole visible area.
    diag_end = max(x_max, y_max)
    ax1.plot([0, diag_end], [0, diag_end], "k--", linewidth=0.5)
    ax1.set_xlabel("True Loadings", fontsize=12)
    ax1.set_ylabel("Predicted Loadings", fontsize=12)
    ax1.set_title(label_plot, fontsize=14)
    # Position is in axes coordinates (0..1) because of transAxes, so it must
    # not depend on the data limits (the old ``x_max - 1.5`` left the plot
    # for any x_max other than 2).
    ax1.text(
        0.5,
        0.93,
        f"r = {corr_gfm:.5f}",
        transform=ax1.transAxes,
        fontsize=13,
        weight="bold",
    )

    # Colorbar
    cbar = fig.colorbar(h1[3], ax=ax1, pad=0.02)
    cbar.set_label("Number of samples", fontsize=10)

    # Style adjustments
    ax1.set_xlim(x_min, x_max)
    ax1.set_ylim(y_min, y_max)
    ax1.grid(True, linewidth=0.3)
    ax1.tick_params(axis="both", labelsize=10)

    plt.tight_layout()
    plt.savefig(f"mass_correlation_density_{model_name}.png", bbox_inches="tight")
    plt.show()
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
def plot_cm(TN, FP, FN, TP, model_name, label_plot):
    """
    Plot the 2x2 overload confusion matrix as an annotated heatmap.

    Rows are the true class and columns the predicted class, both ordered
    ("Non-overload", "Overload"). The figure is saved to
    ``confusion_matrix_overload_{model_name}.png`` in the current working
    directory and then shown.

    Parameters:
    - TN, FP, FN, TP: confusion-matrix counts. They are annotated with the
      integer format ``"d"``, so they are expected to be integers.
    - model_name: identifier inserted into the output file name.
    - label_plot: appended to the "Confusion Matrix" title.
    """
    class_names = ["Non-overload", "Overload"]
    counts = np.array(
        [
            [TN, FP],  # true Non-overload: predicted Non-overload / Overload
            [FN, TP],  # true Overload: predicted Non-overload / Overload
        ]
    )

    _, axis = plt.subplots(figsize=(6, 6))
    sns.heatmap(
        counts,
        ax=axis,
        cmap="Blues",
        annot=True,
        annot_kws={"size": 14},
        fmt="d",
        cbar=False,
        square=True,
        linewidths=0.5,
        xticklabels=class_names,
        yticklabels=class_names,
    )

    axis.set_xlabel("Predicted", fontsize=12)
    axis.set_ylabel("True", fontsize=12)
    axis.set_title(f"Confusion Matrix {label_plot}", fontsize=14)
    axis.tick_params(axis="both", labelsize=12)

    plt.tight_layout()
    plt.savefig(f"confusion_matrix_overload_{model_name}.png", bbox_inches="tight")
    plt.show()
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def plot_loading_predictions(
    loadings_pred,
    loadings_dc,
    loadings_gt,
    prediction_dir,
    label_plot,
):
    """
    Overlay density histograms of predicted, DC-solver and ground-truth loadings.

    All three series share the same 100 bin edges, computed over their
    combined finite values, so their densities are comparable bin by bin.
    Non-finite values are left out of the histograms. The y-axis is
    logarithmic. The histograms are drawn on the current matplotlib figure,
    saved to ``distribution_loading_predictions_{prediction_dir}.png`` in the
    current working directory, and then shown.

    Parameters:
    - loadings_pred: 1-D array-like of model-predicted loadings.
    - loadings_dc: 1-D array-like of loadings from the DC solver.
    - loadings_gt: 1-D array-like of ground-truth loadings.
    - prediction_dir: identifier inserted into the output file name (it is
      not used as a directory).
    - label_plot: legend label of the model-predicted series.
    """
    series = [
        (np.asarray(loadings_pred, dtype=float).ravel(), label_plot),
        (np.asarray(loadings_dc, dtype=float).ravel(), "DC Solver"),
        (np.asarray(loadings_gt, dtype=float).ravel(), "Ground truth"),
    ]

    # Separate ``bins=100`` calls would give each series its own bin edges,
    # which makes the overlaid densities misleading. Use one shared set of
    # edges. Non-finite values (e.g. from zero ratings) are excluded from the
    # range, and explicit edges then ignore them when counting.
    combined = np.concatenate([values for values, _ in series])
    finite = combined[np.isfinite(combined)]
    bins = np.histogram_bin_edges(finite, bins=100) if finite.size else 100

    for values, label in series:
        plt.hist(values, alpha=0.5, label=label, density=True, bins=bins)

    plt.xlabel("Loading Values")
    plt.ylabel("Density")
    plt.yscale("log")
    plt.legend()

    plt.savefig(f"distribution_loading_predictions_{prediction_dir}.png")
    plt.show()
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
def plot_mass_correlation_density_voltage(
    pf_node,
    prediction_dir,
    label_plot,
    x_min=0.85,
    y_min=0.85,
    x_max=1.15,
    y_max=1.15,
):
    """
    Plot a log-scaled 2-D histogram of predicted versus true voltage magnitudes.

    True magnitudes are read from ``pf_node["Vm"]`` and predictions from
    ``pf_node["Vm_pred_corrected"]``. Dashed lines are drawn 0.05 inside each
    axis limit (0.9 and 1.1 with the default limits). A thin diagonal marks
    perfect prediction, and the Pearson correlation is written at the top
    centre. The figure is saved to
    ``mass_correlation_density_voltage_{prediction_dir}.png`` in the current
    working directory and then shown.

    Parameters:
    - pf_node: indexable table with the columns ``Vm`` and
      ``Vm_pred_corrected`` (presumably a pandas DataFrame; confirm with callers).
    - prediction_dir: identifier inserted into the output file name.
    - label_plot: plot title.
    - x_min, x_max: limits of the true-voltage axis.
    - y_min, y_max: limits of the predicted-voltage axis.

    TODO refactor if we pass by parameters a few more plot deets we can use
    plot_mass_correlation_density for both
    """
    true_vm = pf_node["Vm"]
    pred_vm = pf_node["Vm_pred_corrected"]

    bin_step = 0.001  # same bin width on both axes
    x_edges = np.arange(x_min, x_max + bin_step, bin_step)
    y_edges = np.arange(y_min, y_max + bin_step, bin_step)

    # Only non-empty bins (zeros -> NaN) set the top of the colour scale:
    # mean + 3 std of their counts.
    bin_counts = np.histogram2d(true_vm, pred_vm, bins=[x_edges, y_edges])[0]
    bin_counts[bin_counts == 0] = np.nan
    count_cap = np.nanmean(bin_counts) + 3 * np.nanstd(bin_counts)

    r_vm = pearsonr(true_vm, pred_vm)[0]

    fig, ax = plt.subplots(figsize=(8, 6), dpi=400)

    # LogNorm cannot show zero counts, hence the lower bound of 1.
    density = ax.hist2d(
        true_vm,
        pred_vm,
        bins=[x_edges, y_edges],
        norm=LogNorm(vmin=1, vmax=count_cap),
        cmap="inferno",
    )

    limit_style = {"color": "black", "linestyle": "--", "linewidth": 2.0}
    ax.axvline(x_min + 0.05, **limit_style)
    ax.axhline(y_min + 0.05, **limit_style)
    ax.axvline(x_max - 0.05, **limit_style)
    ax.axhline(y_max - 0.05, **limit_style)

    ax.plot([0, 5], [0, 5], "k--", linewidth=0.5)
    ax.set_xlabel("True Voltage Magnitude", fontsize=12)
    ax.set_ylabel("Predicted Voltage magnitude", fontsize=12)
    ax.set_title(label_plot, fontsize=14)
    ax.text(
        0.5,
        0.95,
        f"r = {r_vm:.5f}",
        transform=ax.transAxes,
        fontsize=13,
        weight="bold",
        ha="center",
        va="top",
    )

    colorbar = fig.colorbar(density[3], ax=ax, pad=0.02)
    colorbar.set_label("Number of samples", fontsize=10)

    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.grid(True, linewidth=0.3)
    ax.tick_params(axis="both", labelsize=10)

    plt.tight_layout()
    plt.savefig(
        f"mass_correlation_density_voltage_{prediction_dir}.png",
        bbox_inches="tight",
    )
    plt.show()
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: gridfm-graphkit
|
|
3
|
-
Version: 0.0.4
|
|
3
|
+
Version: 0.0.6
|
|
4
4
|
Summary: Grid Foundation Model
|
|
5
5
|
Author-email: Matteo Mazzonelli <matteo.mazzonelli1@ibm.com>, Alban Puech <apuech@seas.harvard.edu>, Tamara Govindasamy <tamara.govindasamy@ibm.com>, Mangaliso Mngomezulu <mngomezulum@ibm.com>, Etienne Vos <etienne.vos@ibm.com>, Celia Cintas <celia.cintas@ibm.com>, Jonas Weiss <jwe@zurich.ibm.com>
|
|
6
6
|
Maintainer-email: Matteo Mazzonelli <matteo.mazzonelli1@ibm.com>
|
|
@@ -14,30 +14,39 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
|
14
14
|
Requires-Python: <3.13,>=3.10
|
|
15
15
|
Description-Content-Type: text/markdown
|
|
16
16
|
License-File: LICENSE
|
|
17
|
-
Requires-Dist:
|
|
18
|
-
Requires-Dist:
|
|
19
|
-
Requires-Dist:
|
|
20
|
-
Requires-Dist:
|
|
21
|
-
Requires-Dist:
|
|
22
|
-
Requires-Dist:
|
|
23
|
-
Requires-Dist:
|
|
24
|
-
Requires-Dist:
|
|
25
|
-
Requires-Dist:
|
|
26
|
-
Requires-Dist: torchaudio>=2.7.1
|
|
27
|
-
Requires-Dist: torchvision>=0.22.1
|
|
17
|
+
Requires-Dist: torch>2.0
|
|
18
|
+
Requires-Dist: torch-geometric
|
|
19
|
+
Requires-Dist: mlflow
|
|
20
|
+
Requires-Dist: nbformat
|
|
21
|
+
Requires-Dist: networkx
|
|
22
|
+
Requires-Dist: numpy
|
|
23
|
+
Requires-Dist: pandas
|
|
24
|
+
Requires-Dist: plotly
|
|
25
|
+
Requires-Dist: pyyaml
|
|
28
26
|
Requires-Dist: lightning
|
|
27
|
+
Requires-Dist: seaborn
|
|
29
28
|
Provides-Extra: dev
|
|
30
29
|
Requires-Dist: mkdocs-material; extra == "dev"
|
|
31
30
|
Requires-Dist: mkdocstrings[python]; extra == "dev"
|
|
32
|
-
Requires-Dist: pre-commit
|
|
33
|
-
Requires-Dist: bandit
|
|
31
|
+
Requires-Dist: pre-commit; extra == "dev"
|
|
32
|
+
Requires-Dist: bandit; extra == "dev"
|
|
34
33
|
Requires-Dist: build; extra == "dev"
|
|
35
34
|
Provides-Extra: test
|
|
36
35
|
Requires-Dist: pytest; extra == "test"
|
|
37
36
|
Requires-Dist: pytest-cov; extra == "test"
|
|
38
37
|
Dynamic: license-file
|
|
39
38
|
|
|
40
|
-
|
|
39
|
+
<p align="center">
|
|
40
|
+
<img src="https://raw.githubusercontent.com/gridfm/gridfm-graphkit/refs/heads/main/docs/figs/KIT.png" alt="GridFM logo" style="width: 40%; height: auto;"/>
|
|
41
|
+
<br/>
|
|
42
|
+
</p>
|
|
43
|
+
|
|
44
|
+
<p align="center" style="font-size: 25px;">
|
|
45
|
+
<b>gridfm-graphkit</b>
|
|
46
|
+
</p>
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
[](https://doi.org/10.5281/zenodo.17016737)
|
|
41
50
|
[](https://gridfm.github.io/gridfm-graphkit/)
|
|
42
51
|

|
|
43
52
|

|
|
@@ -47,11 +56,6 @@ This library is brought to you by the GridFM team to train, finetune and interac
|
|
|
47
56
|
|
|
48
57
|
---
|
|
49
58
|
|
|
50
|
-
<p align="center">
|
|
51
|
-
<img src="https://raw.githubusercontent.com/gridfm/gridfm-graphkit/refs/heads/main/docs/figs/pre_training.png" alt="GridFM logo"/>
|
|
52
|
-
<br/>
|
|
53
|
-
</p>
|
|
54
|
-
|
|
55
59
|
# Installation
|
|
56
60
|
|
|
57
61
|
You can install `gridfm-graphkit` directly from PyPI:
|
|
@@ -13,6 +13,7 @@ gridfm_graphkit.egg-info/top_level.txt
|
|
|
13
13
|
gridfm_graphkit/datasets/__init__.py
|
|
14
14
|
gridfm_graphkit/datasets/globals.py
|
|
15
15
|
gridfm_graphkit/datasets/normalizers.py
|
|
16
|
+
gridfm_graphkit/datasets/postprocessing.py
|
|
16
17
|
gridfm_graphkit/datasets/powergrid_datamodule.py
|
|
17
18
|
gridfm_graphkit/datasets/powergrid_dataset.py
|
|
18
19
|
gridfm_graphkit/datasets/transforms.py
|
|
@@ -29,6 +30,7 @@ gridfm_graphkit/training/__init__.py
|
|
|
29
30
|
gridfm_graphkit/training/callbacks.py
|
|
30
31
|
gridfm_graphkit/training/loss.py
|
|
31
32
|
gridfm_graphkit/utils/__init__.py
|
|
33
|
+
gridfm_graphkit/utils/utils.py
|
|
32
34
|
gridfm_graphkit/utils/visualization.py
|
|
33
35
|
tests/test_data_module.py
|
|
34
36
|
tests/test_full_pipeline.py
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
torch>2.0
|
|
2
|
+
torch-geometric
|
|
3
|
+
mlflow
|
|
4
|
+
nbformat
|
|
5
|
+
networkx
|
|
6
|
+
numpy
|
|
7
|
+
pandas
|
|
8
|
+
plotly
|
|
9
|
+
pyyaml
|
|
10
|
+
lightning
|
|
11
|
+
seaborn
|
|
12
|
+
|
|
13
|
+
[dev]
|
|
14
|
+
mkdocs-material
|
|
15
|
+
mkdocstrings[python]
|
|
16
|
+
pre-commit
|
|
17
|
+
bandit
|
|
18
|
+
build
|
|
19
|
+
|
|
20
|
+
[test]
|
|
21
|
+
pytest
|
|
22
|
+
pytest-cov
|
|
@@ -9,7 +9,7 @@ namespaces = false
|
|
|
9
9
|
[project]
|
|
10
10
|
name = "gridfm-graphkit"
|
|
11
11
|
description = "Grid Foundation Model"
|
|
12
|
-
version = "0.0.
|
|
12
|
+
version = "0.0.6"
|
|
13
13
|
readme = "README.md"
|
|
14
14
|
license = "Apache-2.0"
|
|
15
15
|
requires-python = ">=3.10,<3.13"
|
|
@@ -40,26 +40,25 @@ classifiers = [
|
|
|
40
40
|
|
|
41
41
|
|
|
42
42
|
dependencies = [
|
|
43
|
-
"
|
|
44
|
-
"
|
|
45
|
-
"
|
|
46
|
-
"
|
|
47
|
-
"
|
|
48
|
-
"
|
|
49
|
-
"
|
|
50
|
-
"
|
|
51
|
-
"
|
|
52
|
-
"torchaudio>=2.7.1",
|
|
53
|
-
"torchvision>=0.22.1",
|
|
43
|
+
"torch>2.0",
|
|
44
|
+
"torch-geometric",
|
|
45
|
+
"mlflow",
|
|
46
|
+
"nbformat",
|
|
47
|
+
"networkx",
|
|
48
|
+
"numpy",
|
|
49
|
+
"pandas",
|
|
50
|
+
"plotly",
|
|
51
|
+
"pyyaml",
|
|
54
52
|
"lightning",
|
|
53
|
+
"seaborn",
|
|
55
54
|
]
|
|
56
55
|
|
|
57
56
|
[project.optional-dependencies]
|
|
58
57
|
dev = [
|
|
59
58
|
"mkdocs-material",
|
|
60
59
|
"mkdocstrings[python]",
|
|
61
|
-
"pre-commit
|
|
62
|
-
"bandit
|
|
60
|
+
"pre-commit",
|
|
61
|
+
"bandit",
|
|
63
62
|
"build"
|
|
64
63
|
]
|
|
65
64
|
|
|
@@ -1,99 +0,0 @@
|
|
|
1
|
-
import networkx as nx
|
|
2
|
-
from gridfm_graphkit.training.loss import PBELoss
|
|
3
|
-
from gridfm_graphkit.datasets.globals import PQ, PV, REF
|
|
4
|
-
import matplotlib.pyplot as plt
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
def visualize_error(data_point, output, node_normalizer):
|
|
8
|
-
loss = PBELoss(visualization=True)
|
|
9
|
-
|
|
10
|
-
loss_dict = loss(
|
|
11
|
-
output,
|
|
12
|
-
data_point.y,
|
|
13
|
-
data_point.edge_index,
|
|
14
|
-
data_point.edge_attr,
|
|
15
|
-
data_point.mask,
|
|
16
|
-
)
|
|
17
|
-
active_loss = loss_dict["Nodal Active Power Loss in p.u."]
|
|
18
|
-
active_loss = active_loss.cpu() * node_normalizer.baseMVA
|
|
19
|
-
|
|
20
|
-
# Create a graph
|
|
21
|
-
G = nx.Graph()
|
|
22
|
-
edges = [
|
|
23
|
-
(u, v)
|
|
24
|
-
for u, v in zip(
|
|
25
|
-
data_point.edge_index[0].tolist(),
|
|
26
|
-
data_point.edge_index[1].tolist(),
|
|
27
|
-
)
|
|
28
|
-
if u != v
|
|
29
|
-
]
|
|
30
|
-
G.add_edges_from(edges)
|
|
31
|
-
|
|
32
|
-
# Assign labels based on node type
|
|
33
|
-
node_shapes = {"REF": "s", "PV": "H", "PQ": "o"}
|
|
34
|
-
num_nodes = data_point.x.shape[0]
|
|
35
|
-
mask_PQ = data_point.x[:, PQ] == 1
|
|
36
|
-
mask_PV = data_point.x[:, PV] == 1
|
|
37
|
-
mask_REF = data_point.x[:, REF] == 1
|
|
38
|
-
node_labels = {}
|
|
39
|
-
for i in range(num_nodes):
|
|
40
|
-
if mask_REF[i]:
|
|
41
|
-
node_labels[i] = "REF"
|
|
42
|
-
elif mask_PV[i]:
|
|
43
|
-
node_labels[i] = "PV"
|
|
44
|
-
elif mask_PQ[i]:
|
|
45
|
-
node_labels[i] = "PQ"
|
|
46
|
-
|
|
47
|
-
# Set node positions
|
|
48
|
-
pos = nx.spring_layout(G, seed=42)
|
|
49
|
-
|
|
50
|
-
# Define colormap
|
|
51
|
-
cmap = plt.cm.viridis
|
|
52
|
-
vmin = min(active_loss)
|
|
53
|
-
vmax = max(active_loss)
|
|
54
|
-
norm = plt.Normalize(vmin=vmin, vmax=vmax)
|
|
55
|
-
|
|
56
|
-
# Create a figure and axis
|
|
57
|
-
fig, ax = plt.subplots(figsize=(13, 7))
|
|
58
|
-
|
|
59
|
-
# Draw nodes with heatmap coloring
|
|
60
|
-
for node_type, shape in node_shapes.items():
|
|
61
|
-
nodes = [i for i in node_labels if node_labels[i] == node_type]
|
|
62
|
-
nx.draw_networkx_nodes(
|
|
63
|
-
G,
|
|
64
|
-
pos,
|
|
65
|
-
nodelist=nodes,
|
|
66
|
-
node_color=[active_loss[i] for i in nodes],
|
|
67
|
-
cmap=cmap,
|
|
68
|
-
node_size=800,
|
|
69
|
-
ax=ax,
|
|
70
|
-
vmin=vmin,
|
|
71
|
-
vmax=vmax,
|
|
72
|
-
node_shape=shape,
|
|
73
|
-
)
|
|
74
|
-
|
|
75
|
-
# Draw edges
|
|
76
|
-
nx.draw_networkx_edges(G, pos, edge_color="gray", alpha=0.5, ax=ax)
|
|
77
|
-
|
|
78
|
-
# Draw labels (node types)
|
|
79
|
-
nx.draw_networkx_labels(
|
|
80
|
-
G,
|
|
81
|
-
pos,
|
|
82
|
-
labels=node_labels,
|
|
83
|
-
font_size=10,
|
|
84
|
-
font_color="white",
|
|
85
|
-
font_weight="bold",
|
|
86
|
-
ax=ax,
|
|
87
|
-
)
|
|
88
|
-
|
|
89
|
-
# Add colorbar
|
|
90
|
-
cbar = plt.colorbar(plt.cm.ScalarMappable(cmap=cmap, norm=norm), ax=ax)
|
|
91
|
-
cbar.set_label("Active Power Residuals (MW)", fontsize=12)
|
|
92
|
-
cbar.ax.tick_params(labelsize=12)
|
|
93
|
-
|
|
94
|
-
for spine in ax.spines.values():
|
|
95
|
-
spine.set_linewidth(2) # Adjust thickness here (e.g., 2 or any value)
|
|
96
|
-
|
|
97
|
-
# Show plot
|
|
98
|
-
plt.title("Nodal Active Power Residuals", fontsize=14, fontweight="bold")
|
|
99
|
-
plt.show()
|
|
@@ -1,23 +0,0 @@
|
|
|
1
|
-
mlflow>=3.1.0
|
|
2
|
-
nbformat>=5.10.4
|
|
3
|
-
networkx>=3.4.2
|
|
4
|
-
numpy>=2.2.6
|
|
5
|
-
pandas>=2.3.0
|
|
6
|
-
plotly>=6.1.2
|
|
7
|
-
pyyaml>=6.0.2
|
|
8
|
-
torch>=2.7.1
|
|
9
|
-
torch-geometric>=2.6.1
|
|
10
|
-
torchaudio>=2.7.1
|
|
11
|
-
torchvision>=0.22.1
|
|
12
|
-
lightning
|
|
13
|
-
|
|
14
|
-
[dev]
|
|
15
|
-
mkdocs-material
|
|
16
|
-
mkdocstrings[python]
|
|
17
|
-
pre-commit>=4.2.0
|
|
18
|
-
bandit>=1.8.5
|
|
19
|
-
build
|
|
20
|
-
|
|
21
|
-
[test]
|
|
22
|
-
pytest
|
|
23
|
-
pytest-cov
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit/datasets/powergrid_datamodule.py
RENAMED
|
File without changes
|
{gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit/datasets/powergrid_dataset.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit/tasks/feature_reconstruction_task.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.6}/gridfm_graphkit.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|