gridfm-graphkit 0.0.4__tar.gz → 0.0.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/PKG-INFO +2 -1
  2. gridfm_graphkit-0.0.5/gridfm_graphkit/datasets/postprocessing.py +83 -0
  3. gridfm_graphkit-0.0.5/gridfm_graphkit/utils/utils.py +42 -0
  4. gridfm_graphkit-0.0.5/gridfm_graphkit/utils/visualization.py +513 -0
  5. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/gridfm_graphkit.egg-info/PKG-INFO +2 -1
  6. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/gridfm_graphkit.egg-info/SOURCES.txt +2 -0
  7. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/gridfm_graphkit.egg-info/requires.txt +1 -0
  8. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/pyproject.toml +2 -1
  9. gridfm_graphkit-0.0.4/gridfm_graphkit/utils/visualization.py +0 -99
  10. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/LICENSE +0 -0
  11. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/README.md +0 -0
  12. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/gridfm_graphkit/__init__.py +0 -0
  13. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/gridfm_graphkit/__main__.py +0 -0
  14. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/gridfm_graphkit/cli.py +0 -0
  15. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/gridfm_graphkit/datasets/__init__.py +0 -0
  16. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/gridfm_graphkit/datasets/globals.py +0 -0
  17. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/gridfm_graphkit/datasets/normalizers.py +0 -0
  18. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/gridfm_graphkit/datasets/powergrid_datamodule.py +0 -0
  19. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/gridfm_graphkit/datasets/powergrid_dataset.py +0 -0
  20. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/gridfm_graphkit/datasets/transforms.py +0 -0
  21. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/gridfm_graphkit/datasets/utils.py +0 -0
  22. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/gridfm_graphkit/io/__init__.py +0 -0
  23. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/gridfm_graphkit/io/param_handler.py +0 -0
  24. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/gridfm_graphkit/io/registries.py +0 -0
  25. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/gridfm_graphkit/models/__init__.py +0 -0
  26. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/gridfm_graphkit/models/gnn_transformer.py +0 -0
  27. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/gridfm_graphkit/models/gps_transformer.py +0 -0
  28. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/gridfm_graphkit/tasks/__init__.py +0 -0
  29. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/gridfm_graphkit/tasks/feature_reconstruction_task.py +0 -0
  30. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/gridfm_graphkit/training/__init__.py +0 -0
  31. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/gridfm_graphkit/training/callbacks.py +0 -0
  32. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/gridfm_graphkit/training/loss.py +0 -0
  33. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/gridfm_graphkit/utils/__init__.py +0 -0
  34. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/gridfm_graphkit.egg-info/dependency_links.txt +0 -0
  35. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/gridfm_graphkit.egg-info/entry_points.txt +0 -0
  36. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/gridfm_graphkit.egg-info/top_level.txt +0 -0
  37. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/setup.cfg +0 -0
  38. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/tests/test_data_module.py +0 -0
  39. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/tests/test_full_pipeline.py +0 -0
  40. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/tests/test_losses.py +0 -0
  41. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/tests/test_model_outputs.py +0 -0
  42. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/tests/test_normalization.py +0 -0
  43. {gridfm_graphkit-0.0.4 → gridfm_graphkit-0.0.5}/tests/test_yaml_configs.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gridfm-graphkit
3
- Version: 0.0.4
3
+ Version: 0.0.5
4
4
  Summary: Grid Foundation Model
5
5
  Author-email: Matteo Mazzonelli <matteo.mazzonelli1@ibm.com>, Alban Puech <apuech@seas.harvard.edu>, Tamara Govindasamy <tamara.govindasamy@ibm.com>, Mangaliso Mngomezulu <mngomezulum@ibm.com>, Etienne Vos <etienne.vos@ibm.com>, Celia Cintas <celia.cintas@ibm.com>, Jonas Weiss <jwe@zurich.ibm.com>
6
6
  Maintainer-email: Matteo Mazzonelli <matteo.mazzonelli1@ibm.com>
@@ -26,6 +26,7 @@ Requires-Dist: torch-geometric>=2.6.1
26
26
  Requires-Dist: torchaudio>=2.7.1
27
27
  Requires-Dist: torchvision>=0.22.1
28
28
  Requires-Dist: lightning
29
+ Requires-Dist: seaborn
29
30
  Provides-Extra: dev
30
31
  Requires-Dist: mkdocs-material; extra == "dev"
31
32
  Requires-Dist: mkdocstrings[python]; extra == "dev"
@@ -0,0 +1,83 @@
1
+ import numpy as np
2
+ from scipy.sparse import csr_matrix
3
+
4
+
5
def compute_branch_currents_kA(Yf, Yt, V, Vf_base_kV, Vt_base_kV, sn_mva):
    """
    Compute branch current magnitudes in kA at both ends of every branch.

    The per-unit end currents are ``I_f = Yf @ V`` and ``I_t = Yt @ V``. Their
    magnitudes are converted to kA with the base current
    ``I_base = sn_mva / (sqrt(3) * V_base_kV)`` of the bus at that end.

    Parameters:
    - Yf: from-end branch admittance matrix (dense or scipy sparse), one row per branch.
    - Yt: to-end branch admittance matrix (dense or scipy sparse), one row per branch.
    - V: complex bus voltage vector in per-unit.
    - Vf_base_kV: base voltage in kV of each branch's from-bus.
    - Vt_base_kV: base voltage in kV of each branch's to-bus.
    - sn_mva: system base power in MVA.

    Returns:
    - (If_kA, It_kA): from-end and to-end current magnitudes in kA.
    """
    sqrt3 = np.sqrt(3)
    from_end_kA = np.abs(Yf @ V) * sn_mva / (sqrt3 * Vf_base_kV)
    to_end_kA = np.abs(Yt @ V) * sn_mva / (sqrt3 * Vt_base_kV)
    return from_end_kA, to_end_kA
19
+
20
+
21
def compute_loading(If_kA, It_kA, Vf_base_kV, Vt_base_kV, rate_a):
    """
    Compute per-branch loading from current magnitudes and branch ratings.

    Each branch rating is converted into a current limit at both ends,
    ``limit = rate_a / (sqrt(3) * V_base_kV)``. The loading of an end is its
    current divided by that limit, and the branch loading is the larger of
    the two ends.

    Parameters:
    - If_kA: from-side current magnitudes in kA, one per branch.
    - It_kA: to-side current magnitudes in kA, one per branch.
    - Vf_base_kV: base voltage in kV of each branch's from-bus.
    - Vt_base_kV: base voltage in kV of each branch's to-bus.
    - rate_a: branch ratings (RATE_A), one per branch. Assumed to be in MVA so
      that the limits come out in kA like the currents.

    Returns:
    - loading: element-wise maximum of from- and to-side loading, one per branch.

    Note: a branch with ``rate_a == 0`` has a zero limit. For numpy inputs its
    loading therefore follows numpy division semantics: inf, or nan when the
    current is also zero.
    """
    sqrt3 = np.sqrt(3)

    # Current limits in kA at each end of the branch.
    limitf = rate_a / (Vf_base_kV * sqrt3)
    limitt = rate_a / (Vt_base_kV * sqrt3)

    loadingf = If_kA / limitf
    loadingt = It_kA / limitt

    return np.maximum(loadingf, loadingt)
43
+
44
+
45
def create_admittance_matrix(bus_params, edge_params, sn_mva=100):
    """
    Build the sparse branch admittance matrices Yf and Yt.

    Row ``b`` of ``Yf`` holds ``Yff[b]`` in the column of the branch's from-bus
    and ``Yft[b]`` in the column of its to-bus. ``Yt`` is built the same way
    from ``Ytf`` / ``Ytt``. Multiplying either matrix by a complex bus voltage
    vector therefore gives the per-branch end currents.

    Parameters:
    - bus_params: pandas DataFrame of buses with a "baseKV" column. Row
      positions are used as bus indices.
    - edge_params: pandas DataFrame of branches with "from_bus", "to_bus" and
      the admittance columns "Yff_r", "Yff_i", "Yft_r", "Yft_i", "Ytf_r",
      "Ytf_i", "Ytt_r", "Ytt_i" (real / imaginary parts).
    - sn_mva: system base power in MVA. Not used by this function.

    Returns:
    - Yf: scipy csr_matrix of shape (n_branches, n_buses), from-end admittances.
    - Yt: scipy csr_matrix of shape (n_branches, n_buses), to-end admittances.
    - Vf_base_kV: base voltage (kV) of each branch's from-bus.
    - Vt_base_kV: base voltage (kV) of each branch's to-bus.
    """
    base_kv = bus_params["baseKV"].values

    from_idx = edge_params["from_bus"].values.astype(np.int32)
    to_idx = edge_params["to_bus"].values.astype(np.int32)

    y_ff = edge_params["Yff_r"].values + 1j * edge_params["Yff_i"].values
    y_ft = edge_params["Yft_r"].values + 1j * edge_params["Yft_i"].values
    y_tf = edge_params["Ytf_r"].values + 1j * edge_params["Ytf_i"].values
    y_tt = edge_params["Ytt_r"].values + 1j * edge_params["Ytt_i"].values

    n_branches = edge_params.shape[0]
    n_buses = bus_params.shape[0]
    shape = (n_branches, n_buses)

    # Each branch contributes two entries per matrix (from-bus column and
    # to-bus column), so the row indices are 0..n_branches-1 repeated twice.
    rows = np.concatenate((np.arange(n_branches), np.arange(n_branches)))
    cols = np.concatenate((from_idx, to_idx))

    Yf = csr_matrix((np.concatenate((y_ff, y_ft)), (rows, cols)), shape=shape)
    Yt = csr_matrix((np.concatenate((y_tf, y_tt)), (rows, cols)), shape=shape)

    return Yf, Yt, base_kv[from_idx], base_kv[to_idx]
@@ -0,0 +1,42 @@
1
+ def compute_cm_metrics(y_test, y_pred, model_name, label_plot):
2
+ """
3
+ Compute confusion matrix (TP,FP,TN,FN) for predicted overleads along with their respective rates and accuracy metric.
4
+
5
+ Parameters:
6
+ - y_pred: predicted overlads
7
+ - y_test: ground truth overloads
8
+ - prediction_dir:
9
+ - label_plot:
10
+ """
11
+
12
+ TP = (y_test & y_pred).sum()
13
+ FP = ((~y_test) & y_pred).sum()
14
+ TN = ((~y_test) & (~y_pred)).sum()
15
+ FN = (y_test & (~y_pred)).sum()
16
+
17
+ # accuracy
18
+ accuracy = (TP + TN) / (TP + FP + TN + FN)
19
+ print(f"Accuracy: {accuracy:.3f}")
20
+
21
+ TPR = TP / (TP + FN)
22
+ FPR = FP / (FP + TN)
23
+ TNR = TN / (TN + FP)
24
+ FNR = FN / (FN + TP)
25
+ # TODO change text to fit both overloadings and voltage violations
26
+ print("Confusion Matrix:")
27
+ print(f"TP: {TP}, FP: {FP}, TN: {TN}, FN: {FN}")
28
+ print(
29
+ f"GridFM\nTPR: {TPR:.3f} (percentage of overloadings correctly predicted)\nFPR: {FPR:.3f} (percentage of non-overloadings predicted as overloadings)\nTNR: {TNR:.2f}\nFNR: {FNR:.2f}",
30
+ )
31
+ with open(f"metrics_overloading_{model_name}.txt", "w") as f:
32
+ f.write(f"Accuracy: {accuracy:.3f}\n")
33
+ f.write("Confusion Matrix:\n")
34
+ f.write(f"TP: {TP}, FP: {FP}, TN: {TN}, FN: {FN}\n")
35
+ f.write(f"{label_plot} Metrics:\n")
36
+ f.write(f"TPR: {TPR:.5f} (percentage of overloadings correctly predicted)\n")
37
+ f.write(
38
+ f"FPR: {FPR:.5f} (percentage of non-overloadings predicted as overloadings)\n",
39
+ )
40
+ f.write(f"TNR: {TNR:.5f}\n")
41
+ f.write(f"FNR: {FNR:.5f}\n")
42
+ return TP, FP, TN, FN
@@ -0,0 +1,513 @@
1
+ from gridfm_graphkit.training.loss import PBELoss
2
+ from gridfm_graphkit.datasets.globals import PQ, PV, REF
3
+
4
+ import networkx as nx
5
+ import matplotlib.pyplot as plt
6
+ from matplotlib.colors import LogNorm
7
+ from scipy.stats import pearsonr
8
+ import seaborn as sns
9
+ import numpy as np
10
+ import copy
11
+
12
+
13
def visualize_error(data_point, output, node_normalizer):
    """
    Plot nodal active-power residuals of a prediction as a graph heatmap.

    The residuals are taken from ``PBELoss(visualization=True)`` under the key
    "Nodal Active Power Loss in p.u." and multiplied by
    ``node_normalizer.baseMVA`` before plotting (the colour bar is labelled MW).
    Marker shape encodes the bus type: REF = square, PV = hexagon, PQ = circle.

    Parameters:
    - data_point: graph sample with ``x`` (bus-type indicator columns indexed by
      PQ/PV/REF), ``y``, ``edge_index``, ``edge_attr`` and ``mask``.
    - output: model prediction, passed to the loss alongside ``data_point.y``.
    - node_normalizer: object exposing ``baseMVA``.
    """
    loss = PBELoss(visualization=True)

    loss_dict = loss(
        output,
        data_point.y,
        data_point.edge_index,
        data_point.edge_attr,
        data_point.mask,
    )
    active_loss = loss_dict["Nodal Active Power Loss in p.u."]
    # Detach and move to host once so the colour values are plain floats,
    # independent of the tensor's device or autograd state.
    active_loss = (active_loss.detach().cpu() * node_normalizer.baseMVA).numpy()

    num_nodes = data_point.x.shape[0]

    # Build the graph from the edge list, skipping self-loops.
    G = nx.Graph()
    G.add_edges_from(
        (u, v)
        for u, v in zip(
            data_point.edge_index[0].tolist(),
            data_point.edge_index[1].tolist(),
        )
        if u != v
    )
    # Make sure every bus is a node (appended after the edge-derived ones so
    # the layout of connected buses is unchanged). A bus with no
    # non-self-loop edge would otherwise have no position, and drawing it
    # would raise a KeyError.
    G.add_nodes_from(range(num_nodes))

    # Label buses by type; checked in priority order REF > PV > PQ.
    node_shapes = {"REF": "s", "PV": "H", "PQ": "o"}
    type_masks = (
        ("REF", data_point.x[:, REF] == 1),
        ("PV", data_point.x[:, PV] == 1),
        ("PQ", data_point.x[:, PQ] == 1),
    )
    node_labels = {}
    for i in range(num_nodes):
        for node_type, type_mask in type_masks:
            if type_mask[i]:
                node_labels[i] = node_type
                break

    pos = nx.spring_layout(G, seed=42)

    cmap = plt.cm.viridis
    vmin = float(active_loss.min())
    vmax = float(active_loss.max())
    norm = plt.Normalize(vmin=vmin, vmax=vmax)

    fig, ax = plt.subplots(figsize=(13, 7))

    # Draw nodes with heatmap colouring, one call per bus type (marker shape).
    for node_type, shape in node_shapes.items():
        nodes = [i for i, label in node_labels.items() if label == node_type]
        nx.draw_networkx_nodes(
            G,
            pos,
            nodelist=nodes,
            node_color=[active_loss[i] for i in nodes],
            cmap=cmap,
            node_size=800,
            ax=ax,
            vmin=vmin,
            vmax=vmax,
            node_shape=shape,
        )

    nx.draw_networkx_edges(G, pos, edge_color="gray", alpha=0.5, ax=ax)
    nx.draw_networkx_labels(
        G,
        pos,
        labels=node_labels,
        font_size=10,
        font_color="white",
        font_weight="bold",
        ax=ax,
    )

    cbar = fig.colorbar(plt.cm.ScalarMappable(cmap=cmap, norm=norm), ax=ax)
    cbar.set_label("Active Power Residuals (MW)", fontsize=12)
    cbar.ax.tick_params(labelsize=12)

    for spine in ax.spines.values():
        spine.set_linewidth(2)

    ax.set_title("Nodal Active Power Residuals", fontsize=14, fontweight="bold")
    plt.show()
106
+
107
+
108
def visualize_quantity_heatmap(
    data_point,
    output,
    quantity,
    quantity_name,
    unit,
    node_normalizer,
):
    """
    Visualize ground-truth, masked and reconstructed values of one nodal quantity.

    Draws three side-by-side graph heatmaps that share one colour scale:
    1. the denormalized ground truth;
    2. the ground truth with masked nodes greyed out;
    3. the reconstruction (model output at masked entries, ground truth elsewhere).
    Marker shape encodes the bus type (REF = square, PV = hexagon, PQ = circle).

    Parameters:
    - data_point: graph sample with ``x`` (bus-type indicator columns indexed by
      PQ/PV/REF), ``y``, ``edge_index`` and a boolean ``mask`` (True = masked).
    - output: model prediction in normalized space, indexed like ``data_point.y``.
    - quantity: column index of the feature to plot (e.g. VM, PD, QD, PG, QG, VA).
    - quantity_name: human-readable name used in titles and on the colour bar.
    - unit: unit string shown on the colour bar.
    - node_normalizer: normalizer whose ``inverse_transform`` is applied to
      both the prediction and the ground truth.

    ``data_point`` and ``output`` are deep-copied before any processing.
    """
    data_point = copy.deepcopy(data_point)
    output = copy.deepcopy(output)

    output = node_normalizer.inverse_transform(output)
    denormalized_gt = node_normalizer.inverse_transform(data_point.y)

    quantity_mask = data_point.mask[:, quantity]
    gt_values = denormalized_gt[:, quantity]
    predicted_values = output[:, quantity]
    # Only masked entries are reconstructed; elsewhere show the ground truth.
    predicted_values[~quantity_mask] = denormalized_gt[~quantity_mask, quantity]

    def to_numpy(values):
        # Objects exposing ``detach`` (tensors) are detached and moved to host.
        if hasattr(values, "detach"):
            values = values.detach().cpu().numpy()
        return np.asarray(values)

    gt_values = to_numpy(gt_values)
    predicted_values = to_numpy(predicted_values)
    masked_node_indices = np.flatnonzero(to_numpy(quantity_mask)).tolist()

    num_nodes = data_point.x.shape[0]

    # Build the graph from the edge list, skipping self-loops.
    G = nx.Graph()
    G.add_edges_from(
        (u, v)
        for u, v in zip(
            data_point.edge_index[0].tolist(),
            data_point.edge_index[1].tolist(),
        )
        if u != v
    )
    # Make sure every bus is a node (appended after the edge-derived ones so
    # the layout of connected buses is unchanged). A bus with no
    # non-self-loop edge would otherwise have no position, and drawing it
    # would raise a KeyError.
    G.add_nodes_from(range(num_nodes))

    # Label buses by type; checked in priority order REF > PV > PQ.
    node_shapes = {"REF": "s", "PV": "H", "PQ": "o"}
    type_masks = (
        ("REF", data_point.x[:, REF] == 1),
        ("PV", data_point.x[:, PV] == 1),
        ("PQ", data_point.x[:, PQ] == 1),
    )
    node_labels = {}
    for i in range(num_nodes):
        for node_type, type_mask in type_masks:
            if type_mask[i]:
                node_labels[i] = node_type
                break

    pos = nx.spring_layout(G, seed=42)
    cmap = plt.cm.viridis
    # One colour scale spanning both ground truth and reconstruction, so that
    # neither panel saturates when their value ranges differ.
    vmin = float(min(gt_values.min(), predicted_values.min()))
    vmax = float(max(gt_values.max(), predicted_values.max()))
    norm = plt.Normalize(vmin=vmin, vmax=vmax)

    def draw_panel(ax, values, title, greyed_nodes=None):
        # Draw one heatmap panel; optionally overlay ``greyed_nodes`` in grey.
        for node_type, shape in node_shapes.items():
            nodes = [i for i, label in node_labels.items() if label == node_type]
            nx.draw_networkx_nodes(
                G,
                pos,
                nodelist=nodes,
                node_color=[values[i] for i in nodes],
                cmap=cmap,
                node_size=390 if node_type == "REF" else 600,
                ax=ax,
                vmin=vmin,
                vmax=vmax,
                node_shape=shape,
            )
        if greyed_nodes is not None:
            nx.draw_networkx_nodes(
                G,
                pos,
                nodelist=greyed_nodes,
                node_color="#D3D3D3",
                node_size=750,
                ax=ax,
            )
        nx.draw_networkx_edges(G, pos, edge_color="gray", alpha=0.5, ax=ax, width=2)
        nx.draw_networkx_labels(
            G,
            pos,
            labels=node_labels,
            font_size=10,
            font_color="white",
            font_weight="bold",
            ax=ax,
        )
        ax.set_title(title, fontsize=14, fontweight="bold")
        for spine in ax.spines.values():
            spine.set_linewidth(2)

    fig, axes = plt.subplots(1, 3, figsize=(22, 8))
    draw_panel(axes[0], gt_values, f"Ground truth {quantity_name}")
    draw_panel(
        axes[1],
        gt_values,
        f"Masked {quantity_name}",
        greyed_nodes=masked_node_indices,
    )
    draw_panel(axes[2], predicted_values, f"Reconstructed {quantity_name}")

    # Shared colour bar in its own axes to the right of the three panels.
    cbar_ax = fig.add_axes([0.93, 0.1, 0.02, 0.8])
    cbar = fig.colorbar(plt.cm.ScalarMappable(cmap=cmap, norm=norm), cax=cbar_ax)
    cbar.set_label(f"{quantity_name} ({unit})", fontsize=12)
    cbar.ax.tick_params(labelsize=12)

    plt.subplots_adjust(right=0.9)
    plt.show()
291
+
292
+
293
+ def plot_mass_correlation_density(
294
+ true_vals,
295
+ gfm_vals,
296
+ model_name,
297
+ label_plot,
298
+ x_max=2,
299
+ y_max=3,
300
+ ):
301
+ """
302
+ TODO docstring
303
+
304
+ """
305
+ # TODO check if these parameters need to be passed by func or default behavior
306
+ vmin = 1
307
+ x_min = 0
308
+ y_min = 0
309
+ bin_width = 0.01 # consistent bin width for both plots
310
+
311
+ # Generate consistent bins
312
+ x_bins = np.arange(x_min, x_max + bin_width, bin_width)
313
+ y_bins = np.arange(y_min, y_max + bin_width, bin_width)
314
+
315
+ # estimate vmax on mean count of elements across bins
316
+ counts, _, _ = np.histogram2d(true_vals, gfm_vals, bins=[x_bins, y_bins])
317
+
318
+ counts[counts == 0] = np.nan
319
+ means = np.nanmean(counts)
320
+ std = np.nanstd(counts)
321
+ vmax = means + 3 * std
322
+
323
+ # Pearson correlations
324
+ corr_gfm, _ = pearsonr(true_vals, gfm_vals)
325
+
326
+ # Create figure with shared x-axis
327
+ fig, ax1 = plt.subplots(figsize=(8, 6), dpi=400)
328
+
329
+ # --- GridFM Mass Correlation ---
330
+ h1 = ax1.hist2d(
331
+ true_vals,
332
+ gfm_vals,
333
+ bins=[x_bins, y_bins],
334
+ norm=LogNorm(vmin=vmin, vmax=vmax),
335
+ cmap="inferno",
336
+ )
337
+ ax1.axvline(1, color="black", linestyle="--", linewidth=2.0)
338
+ ax1.axhline(1, color="black", linestyle="--", linewidth=2.0)
339
+ ax1.plot([0, 5], [0, 5], "k--", linewidth=0.5)
340
+ ax1.set_xlabel("True Loadings", fontsize=12)
341
+ ax1.set_ylabel("Predicted Loadings", fontsize=12)
342
+ ax1.set_title(label_plot, fontsize=14)
343
+ ax1.text(
344
+ x_max - 1.5,
345
+ 0.93,
346
+ f"r = {corr_gfm:.5f}",
347
+ transform=ax1.transAxes,
348
+ fontsize=13,
349
+ weight="bold",
350
+ )
351
+
352
+ # Colorbar
353
+ cbar = fig.colorbar(h1[3], ax=ax1, pad=0.02)
354
+ cbar.set_label("Number of samples", fontsize=10)
355
+
356
+ # Style adjustments
357
+ ax1.set_xlim(x_min, x_max)
358
+ ax1.set_ylim(y_min, y_max)
359
+ ax1.grid(True, linewidth=0.3)
360
+ ax1.tick_params(axis="both", labelsize=10)
361
+
362
+ plt.tight_layout()
363
+ plt.savefig(f"mass_correlation_density_{model_name}.png", bbox_inches="tight")
364
+ plt.show()
365
+
366
+
367
def plot_cm(TN, FP, FN, TP, model_name, label_plot):
    """
    Draw, save and show a 2x2 confusion-matrix heatmap for overload classification.

    Rows are the true class and columns the predicted class, both ordered
    ("Non-overload", "Overload"). The arguments are therefore taken row by
    row as (TN, FP, FN, TP). The figure is saved to
    ``confusion_matrix_overload_{model_name}.png`` in the current directory.

    Parameters:
    - TN, FP, FN, TP: integer confusion-matrix counts.
    - model_name: suffix used in the output file name.
    - label_plot: label appended to the plot title.
    """
    class_names = ["Non-overload", "Overload"]
    matrix = np.array([[TN, FP], [FN, TP]])

    _, axis = plt.subplots(figsize=(6, 6))

    sns.heatmap(
        matrix,
        ax=axis,
        annot=True,
        annot_kws={"size": 14},
        fmt="d",
        cbar=False,
        square=True,
        linewidths=0.5,
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
    )

    axis.set_title(f"Confusion Matrix {label_plot}", fontsize=14)
    axis.set_xlabel("Predicted", fontsize=12)
    axis.set_ylabel("True", fontsize=12)
    axis.tick_params(axis="both", labelsize=12)

    plt.tight_layout()
    plt.savefig(f"confusion_matrix_overload_{model_name}.png", bbox_inches="tight")
    plt.show()
399
+
400
+
401
def plot_loading_predictions(
    loadings_pred,
    loadings_dc,
    loadings_gt,
    prediction_dir,
    label_plot,
):
    """
    Plot overlaid, density-normalised histograms of branch loadings.

    Compares the model's predicted loadings, the DC solver's loadings and the
    ground truth on a log-scaled density axis (100 bins per series). The figure
    is saved to ``distribution_loading_predictions_{prediction_dir}.png`` in the
    current directory and shown.

    Parameters:
    - loadings_pred: model-predicted loadings (1D array-like).
    - loadings_dc: loadings from the DC solver.
    - loadings_gt: ground-truth loadings.
    - prediction_dir: suffix used in the output file name.
    - label_plot: legend label for the model predictions.
    """
    # Draw on a dedicated figure. Bare ``plt.hist`` draws into whatever figure
    # is current, e.g. one left open by a previous plotting call, and would
    # overlay the histograms onto it.
    fig, ax = plt.subplots()

    series = (
        (loadings_pred, label_plot),
        (loadings_dc, "DC Solver"),
        (loadings_gt, "Ground truth"),
    )
    for values, label in series:
        ax.hist(values, alpha=0.5, label=label, density=True, bins=100)

    ax.set_xlabel("Loading Values")
    ax.set_ylabel("Density")
    ax.set_yscale("log")
    ax.legend()

    fig.savefig(f"distribution_loading_predictions_{prediction_dir}.png")
    plt.show()
428
+
429
+
430
def plot_mass_correlation_density_voltage(
    pf_node,
    prediction_dir,
    label_plot,
    x_min=0.85,
    y_min=0.85,
    x_max=1.15,
    y_max=1.15,
):
    """
    Plot a log-scaled 2D density of predicted vs. true voltage magnitudes.

    Uses the "Vm" (true) and "Vm_pred_corrected" (predicted) columns of
    ``pf_node``. Dashed reference lines are drawn 0.05 inside each axis limit
    (0.9 and 1.1 with the default limits), plus a thin identity diagonal. The
    Pearson correlation is written inside the plot. The figure is saved to
    ``mass_correlation_density_voltage_{prediction_dir}.png`` in the current
    directory and shown.

    Parameters:
    - pf_node: table (e.g. a pandas DataFrame) with "Vm" and "Vm_pred_corrected" columns.
    - prediction_dir: suffix used in the output file name.
    - label_plot: plot title.
    - x_min, x_max: axis limits for the true values.
    - y_min, y_max: axis limits for the predicted values.

    TODO: merge with plot_mass_correlation_density once axis labels and
    reference lines can be passed in as parameters.
    """
    true_vm = pf_node["Vm"]
    pred_vm = pf_node["Vm_pred_corrected"]

    vmin = 1  # LogNorm lower bound; avoids log(0) for empty bins
    bin_width = 0.001

    x_bins = np.arange(x_min, x_max + bin_width, bin_width)
    y_bins = np.arange(y_min, y_max + bin_width, bin_width)

    # Upper colour limit: mean + 3 * std of the populated bins' counts. Fall
    # back to vmin when no sample lies inside the plotted range; otherwise the
    # statistics of an empty set are nan and unusable as a LogNorm limit.
    counts, _, _ = np.histogram2d(true_vm, pred_vm, bins=[x_bins, y_bins])
    populated = counts[counts > 0]
    vmax = populated.mean() + 3 * populated.std() if populated.size else vmin

    corr_vm, _ = pearsonr(true_vm, pred_vm)

    fig, ax1 = plt.subplots(figsize=(8, 6), dpi=400)

    h1 = ax1.hist2d(
        true_vm,
        pred_vm,
        bins=[x_bins, y_bins],
        norm=LogNorm(vmin=vmin, vmax=vmax),
        cmap="inferno",
    )
    ax1.axvline(x_min + 0.05, color="black", linestyle="--", linewidth=2.0)
    ax1.axhline(y_min + 0.05, color="black", linestyle="--", linewidth=2.0)
    ax1.axvline(x_max - 0.05, color="black", linestyle="--", linewidth=2.0)
    ax1.axhline(y_max - 0.05, color="black", linestyle="--", linewidth=2.0)

    ax1.plot([0, 5], [0, 5], "k--", linewidth=0.5)
    ax1.set_xlabel("True Voltage Magnitude", fontsize=12)
    ax1.set_ylabel("Predicted Voltage magnitude", fontsize=12)
    ax1.set_title(label_plot, fontsize=14)
    ax1.text(
        0.5,
        0.95,
        f"r = {corr_vm:.5f}",
        transform=ax1.transAxes,
        fontsize=13,
        weight="bold",
        ha="center",
        va="top",
    )

    cbar = fig.colorbar(h1[3], ax=ax1, pad=0.02)
    cbar.set_label("Number of samples", fontsize=10)

    ax1.set_xlim(x_min, x_max)
    ax1.set_ylim(y_min, y_max)
    ax1.grid(True, linewidth=0.3)
    ax1.tick_params(axis="both", labelsize=10)

    plt.tight_layout()
    plt.savefig(
        f"mass_correlation_density_voltage_{prediction_dir}.png",
        bbox_inches="tight",
    )
    plt.show()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gridfm-graphkit
3
- Version: 0.0.4
3
+ Version: 0.0.5
4
4
  Summary: Grid Foundation Model
5
5
  Author-email: Matteo Mazzonelli <matteo.mazzonelli1@ibm.com>, Alban Puech <apuech@seas.harvard.edu>, Tamara Govindasamy <tamara.govindasamy@ibm.com>, Mangaliso Mngomezulu <mngomezulum@ibm.com>, Etienne Vos <etienne.vos@ibm.com>, Celia Cintas <celia.cintas@ibm.com>, Jonas Weiss <jwe@zurich.ibm.com>
6
6
  Maintainer-email: Matteo Mazzonelli <matteo.mazzonelli1@ibm.com>
@@ -26,6 +26,7 @@ Requires-Dist: torch-geometric>=2.6.1
26
26
  Requires-Dist: torchaudio>=2.7.1
27
27
  Requires-Dist: torchvision>=0.22.1
28
28
  Requires-Dist: lightning
29
+ Requires-Dist: seaborn
29
30
  Provides-Extra: dev
30
31
  Requires-Dist: mkdocs-material; extra == "dev"
31
32
  Requires-Dist: mkdocstrings[python]; extra == "dev"
@@ -13,6 +13,7 @@ gridfm_graphkit.egg-info/top_level.txt
13
13
  gridfm_graphkit/datasets/__init__.py
14
14
  gridfm_graphkit/datasets/globals.py
15
15
  gridfm_graphkit/datasets/normalizers.py
16
+ gridfm_graphkit/datasets/postprocessing.py
16
17
  gridfm_graphkit/datasets/powergrid_datamodule.py
17
18
  gridfm_graphkit/datasets/powergrid_dataset.py
18
19
  gridfm_graphkit/datasets/transforms.py
@@ -29,6 +30,7 @@ gridfm_graphkit/training/__init__.py
29
30
  gridfm_graphkit/training/callbacks.py
30
31
  gridfm_graphkit/training/loss.py
31
32
  gridfm_graphkit/utils/__init__.py
33
+ gridfm_graphkit/utils/utils.py
32
34
  gridfm_graphkit/utils/visualization.py
33
35
  tests/test_data_module.py
34
36
  tests/test_full_pipeline.py
@@ -10,6 +10,7 @@ torch-geometric>=2.6.1
10
10
  torchaudio>=2.7.1
11
11
  torchvision>=0.22.1
12
12
  lightning
13
+ seaborn
13
14
 
14
15
  [dev]
15
16
  mkdocs-material
@@ -9,7 +9,7 @@ namespaces = false
9
9
  [project]
10
10
  name = "gridfm-graphkit"
11
11
  description = "Grid Foundation Model"
12
- version = "0.0.4"
12
+ version = "0.0.5"
13
13
  readme = "README.md"
14
14
  license = "Apache-2.0"
15
15
  requires-python = ">=3.10,<3.13"
@@ -52,6 +52,7 @@ dependencies = [
52
52
  "torchaudio>=2.7.1",
53
53
  "torchvision>=0.22.1",
54
54
  "lightning",
55
+ "seaborn",
55
56
  ]
56
57
 
57
58
  [project.optional-dependencies]
@@ -1,99 +0,0 @@
1
- import networkx as nx
2
- from gridfm_graphkit.training.loss import PBELoss
3
- from gridfm_graphkit.datasets.globals import PQ, PV, REF
4
- import matplotlib.pyplot as plt
5
-
6
-
7
def visualize_error(data_point, output, node_normalizer):
    """
    Plot nodal active-power residuals of a prediction as a graph heatmap.

    The residuals are taken from ``PBELoss(visualization=True)`` under the key
    "Nodal Active Power Loss in p.u." and multiplied by
    ``node_normalizer.baseMVA`` before plotting (the colour bar is labelled MW).
    Marker shape encodes the bus type: REF = square, PV = hexagon, PQ = circle.

    Parameters:
    - data_point: graph sample with ``x`` (bus-type indicator columns indexed by
      PQ/PV/REF), ``y``, ``edge_index``, ``edge_attr`` and ``mask``.
    - output: model prediction, passed to the loss alongside ``data_point.y``.
    - node_normalizer: object exposing ``baseMVA``.
    """
    loss = PBELoss(visualization=True)

    loss_dict = loss(
        output,
        data_point.y,
        data_point.edge_index,
        data_point.edge_attr,
        data_point.mask,
    )
    active_loss = loss_dict["Nodal Active Power Loss in p.u."]
    # Detach and move to host once so the colour values are plain floats,
    # independent of the tensor's device or autograd state.
    active_loss = (active_loss.detach().cpu() * node_normalizer.baseMVA).numpy()

    num_nodes = data_point.x.shape[0]

    # Build the graph from the edge list, skipping self-loops.
    G = nx.Graph()
    G.add_edges_from(
        (u, v)
        for u, v in zip(
            data_point.edge_index[0].tolist(),
            data_point.edge_index[1].tolist(),
        )
        if u != v
    )
    # Make sure every bus is a node (appended after the edge-derived ones so
    # the layout of connected buses is unchanged). A bus with no
    # non-self-loop edge would otherwise have no position, and drawing it
    # would raise a KeyError.
    G.add_nodes_from(range(num_nodes))

    # Label buses by type; checked in priority order REF > PV > PQ.
    node_shapes = {"REF": "s", "PV": "H", "PQ": "o"}
    type_masks = (
        ("REF", data_point.x[:, REF] == 1),
        ("PV", data_point.x[:, PV] == 1),
        ("PQ", data_point.x[:, PQ] == 1),
    )
    node_labels = {}
    for i in range(num_nodes):
        for node_type, type_mask in type_masks:
            if type_mask[i]:
                node_labels[i] = node_type
                break

    pos = nx.spring_layout(G, seed=42)

    cmap = plt.cm.viridis
    vmin = float(active_loss.min())
    vmax = float(active_loss.max())
    norm = plt.Normalize(vmin=vmin, vmax=vmax)

    fig, ax = plt.subplots(figsize=(13, 7))

    # Draw nodes with heatmap colouring, one call per bus type (marker shape).
    for node_type, shape in node_shapes.items():
        nodes = [i for i, label in node_labels.items() if label == node_type]
        nx.draw_networkx_nodes(
            G,
            pos,
            nodelist=nodes,
            node_color=[active_loss[i] for i in nodes],
            cmap=cmap,
            node_size=800,
            ax=ax,
            vmin=vmin,
            vmax=vmax,
            node_shape=shape,
        )

    nx.draw_networkx_edges(G, pos, edge_color="gray", alpha=0.5, ax=ax)
    nx.draw_networkx_labels(
        G,
        pos,
        labels=node_labels,
        font_size=10,
        font_color="white",
        font_weight="bold",
        ax=ax,
    )

    cbar = fig.colorbar(plt.cm.ScalarMappable(cmap=cmap, norm=norm), ax=ax)
    cbar.set_label("Active Power Residuals (MW)", fontsize=12)
    cbar.ax.tick_params(labelsize=12)

    for spine in ax.spines.values():
        spine.set_linewidth(2)

    ax.set_title("Nodal Active Power Residuals", fontsize=14, fontweight="bold")
    plt.show()
File without changes