gridfm-graphkit 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gridfm_graphkit/__init__.py +0 -0
- gridfm_graphkit/__main__.py +62 -0
- gridfm_graphkit/cli.py +530 -0
- gridfm_graphkit/datasets/__init__.py +0 -0
- gridfm_graphkit/datasets/data_normalization.py +227 -0
- gridfm_graphkit/datasets/globals.py +19 -0
- gridfm_graphkit/datasets/powergrid.py +192 -0
- gridfm_graphkit/datasets/transforms.py +223 -0
- gridfm_graphkit/datasets/utils.py +65 -0
- gridfm_graphkit/io/__init__.py +0 -0
- gridfm_graphkit/io/param_handler.py +293 -0
- gridfm_graphkit/models/__init__.py +0 -0
- gridfm_graphkit/models/gps_transformer.py +143 -0
- gridfm_graphkit/models/graphTransformer.py +96 -0
- gridfm_graphkit/training/__init__.py +0 -0
- gridfm_graphkit/training/callbacks.py +47 -0
- gridfm_graphkit/training/plugins.py +218 -0
- gridfm_graphkit/training/trainer.py +156 -0
- gridfm_graphkit/utils/__init__.py +0 -0
- gridfm_graphkit/utils/loss.py +198 -0
- gridfm_graphkit/utils/visualization.py +324 -0
- gridfm_graphkit-0.0.1.dist-info/METADATA +163 -0
- gridfm_graphkit-0.0.1.dist-info/RECORD +27 -0
- gridfm_graphkit-0.0.1.dist-info/WHEEL +5 -0
- gridfm_graphkit-0.0.1.dist-info/entry_points.txt +2 -0
- gridfm_graphkit-0.0.1.dist-info/licenses/LICENSE +201 -0
- gridfm_graphkit-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,324 @@
|
|
|
1
|
+
from gridfm_graphkit.utils.loss import PBELoss
|
|
2
|
+
from gridfm_graphkit.datasets.globals import PQ, PV, REF
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import networkx as nx
|
|
6
|
+
import matplotlib.pyplot as plt
|
|
7
|
+
import numpy as np
|
|
8
|
+
import copy
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# Marker shape per bus type; the dict order is also the drawing order.
_NODE_SHAPES = {"REF": "s", "PV": "H", "PQ": "o"}


def _bus_type_labels(x):
    """Return ``{node_index: "REF" | "PV" | "PQ"}`` from the bus-type flag columns of ``x``.

    REF takes precedence over PV, which takes precedence over PQ. Nodes with
    none of the three flags equal to 1 are omitted, so they are neither
    coloured nor labelled in the plots.
    """
    # Convert each flag column once instead of indexing the tensor per node
    # (which forces a device sync per element when ``x`` is on a GPU).
    is_ref = (x[:, REF] == 1).tolist()
    is_pv = (x[:, PV] == 1).tolist()
    is_pq = (x[:, PQ] == 1).tolist()
    labels = {}
    for i, (ref, pv, pq) in enumerate(zip(is_ref, is_pv, is_pq)):
        if ref:
            labels[i] = "REF"
        elif pv:
            labels[i] = "PV"
        elif pq:
            labels[i] = "PQ"
    return labels


def _masked_inference(data_point, model, device):
    """Run ``model`` on a masked copy of ``data_point`` and return ``(copy, output)``.

    ``data_point`` is deep-copied (the caller's object is left untouched) and
    moved to ``device``. The entries of ``x`` selected by ``mask`` are then
    overwritten with ``model.mask_value`` before one forward pass under
    ``torch.no_grad()``.
    """
    data_point = copy.deepcopy(data_point)
    model.eval()
    with torch.no_grad():
        data_point = data_point.to(device)
        num_nodes = data_point.x.shape[0]
        mask_value_expanded = model.mask_value.expand(num_nodes, -1)
        # The mask only spans the leading ``mask.shape[1]`` feature columns of x;
        # slicing gives a view, so the boolean assignment writes into x itself.
        data_point.x[:, : data_point.mask.shape[1]][data_point.mask] = (
            mask_value_expanded[data_point.mask]
        )
        output = model(
            data_point.x,
            data_point.pe,
            data_point.edge_index,
            data_point.edge_attr,
            # All-zero per-node index vector (presumably the batch assignment
            # for a single graph).
            torch.zeros(num_nodes, dtype=torch.long, device=device),
        )
    return data_point, output


def _build_grid_graph(edge_index, num_nodes):
    """Build an undirected ``nx.Graph`` from ``edge_index``, skipping self-loops.

    Node indices ``0..num_nodes-1`` are added *after* the edges. Nodes without
    any non-self-loop edge therefore still get a layout position; drawing a
    node missing from the layout raises ``KeyError``. Adding them last keeps
    the node order (and thus the seeded spring layout) unchanged for graphs in
    which every node has an edge.
    """
    graph = nx.Graph()
    graph.add_edges_from(
        (u, v)
        for u, v in zip(edge_index[0].tolist(), edge_index[1].tolist())
        if u != v
    )
    graph.add_nodes_from(range(num_nodes))
    return graph


def _draw_typed_nodes(ax, graph, pos, node_labels, values, cmap, vmin, vmax, node_sizes):
    """Draw labelled nodes coloured by ``values[i]``, one marker shape per bus type.

    ``node_sizes`` maps each bus type ("REF", "PV", "PQ") to a marker size.
    """
    for node_type, shape in _NODE_SHAPES.items():
        nodes = [i for i, label in node_labels.items() if label == node_type]
        nx.draw_networkx_nodes(
            graph,
            pos,
            nodelist=nodes,
            node_color=[values[i] for i in nodes],
            cmap=cmap,
            node_size=node_sizes[node_type],
            ax=ax,
            vmin=vmin,
            vmax=vmax,
            node_shape=shape,
        )


def _draw_edges_and_labels(ax, graph, pos, node_labels, edge_width):
    """Draw semi-transparent grey edges and bold white bus-type labels on ``ax``."""
    nx.draw_networkx_edges(
        graph, pos, edge_color="gray", alpha=0.5, ax=ax, width=edge_width
    )
    nx.draw_networkx_labels(
        graph,
        pos,
        labels=node_labels,
        font_size=10,
        font_color="white",
        font_weight="bold",
        ax=ax,
    )


def _thicken_spines(ax, linewidth=2):
    """Set the line width of every spine (frame edge) of ``ax``."""
    for spine in ax.spines.values():
        spine.set_linewidth(linewidth)


def visualize_error(data_point, model, baseMVA, device):
    """Plot the nodal active-power residual of ``model`` on one grid sample.

    The sample is masked and passed through ``model``. The
    "Nodal Active Power Loss in p.u." entry returned by
    ``PBELoss(visualization=True)`` is multiplied by ``baseMVA`` and drawn as
    a heatmap over the grid graph. Marker shapes encode the bus type:
    REF = square, PV = hexagon, PQ = circle. The figure is shown with
    ``plt.show()``; nothing is returned.

    Parameters:
        data_point: Power grid sample. Must provide ``x``, ``y``, ``pe``,
            ``edge_index``, ``edge_attr`` and ``mask``. It is not modified.
        model: Trained model exposing ``mask_value``; it is put in eval mode.
        baseMVA: System base power used to convert per-unit residuals
            (colorbar is labelled in MW).
        device: Device on which inference runs.
    """
    # Read bus types from the *unmasked* input so masking can never alter them.
    node_labels = _bus_type_labels(data_point.x)
    data_point, output = _masked_inference(data_point, model, device)

    loss = PBELoss(visualization=True)
    with torch.no_grad():
        loss_dict = loss(
            output,
            data_point.y,
            data_point.edge_index,
            data_point.edge_attr,
            data_point.mask,
        )
    # NOTE(review): assumes this entry is a per-node tensor — confirm in PBELoss.
    active_loss = loss_dict["Nodal Active Power Loss in p.u."].cpu() * baseMVA

    graph = _build_grid_graph(data_point.edge_index, data_point.x.shape[0])
    pos = nx.spring_layout(graph, seed=42)

    cmap = plt.cm.viridis
    vmin = active_loss.min().item()
    vmax = active_loss.max().item()
    norm = plt.Normalize(vmin=vmin, vmax=vmax)

    fig, ax = plt.subplots(figsize=(13, 7))
    _draw_typed_nodes(
        ax,
        graph,
        pos,
        node_labels,
        active_loss,
        cmap,
        vmin,
        vmax,
        node_sizes=dict.fromkeys(_NODE_SHAPES, 800),
    )
    _draw_edges_and_labels(ax, graph, pos, node_labels, edge_width=1.0)

    cbar = fig.colorbar(plt.cm.ScalarMappable(cmap=cmap, norm=norm), ax=ax)
    cbar.set_label("Active Power Residuals (MW)", fontsize=12)
    cbar.ax.tick_params(labelsize=12)

    _thicken_spines(ax)
    ax.set_title("Nodal Active Power Residuals", fontsize=14, fontweight="bold")
    plt.show()


def visualize_quantity_heatmap(
    data_point,
    model,
    quantity,
    quantity_name,
    unit,
    node_normalizer,
    cmap,
    device,
):
    """
    Visualize one node quantity (e.g. VM, PD, QD, PG, QG, VA) as three heatmaps.

    The three panels are drawn side by side and share one colour scale:
        1. "Input grid": denormalized ground truth.
        2. "Masked grid": ground truth, with nodes whose ``quantity`` entry is
           masked drawn in light grey.
        3. "Reconstructed grid": model prediction at masked nodes, ground
           truth elsewhere.
    The colour scale spans both the ground truth and the prediction, so no
    panel is clipped. The figure is shown with ``plt.show()``; nothing is
    returned.

    Parameters:
        data_point: Power grid sample. Must provide ``x``, ``y``, ``pe``,
            ``edge_index``, ``edge_attr`` and ``mask``. It is not modified.
        model: Trained model exposing ``mask_value``; used for inference.
        quantity: Column index of the quantity in ``y`` / the model output
            (and in ``mask``).
        quantity_name: Display name used in the panel titles and colorbar.
        unit: Unit string shown on the colorbar label.
        node_normalizer: Object whose ``inverse_transform`` maps normalized
            node features back to physical values.
        cmap: Matplotlib colormap for the heatmaps.
        device: Device on which inference runs.
    """
    # Read bus types from the *unmasked* input so masking can never alter them.
    node_labels = _bus_type_labels(data_point.x)
    data_point, output = _masked_inference(data_point, model, device)

    with torch.no_grad():
        output = node_normalizer.inverse_transform(output)
        denormalized_gt = node_normalizer.inverse_transform(data_point.y)

    is_masked = data_point.mask[:, quantity]
    gt_values = denormalized_gt[:, quantity]
    # Only masked entries are reconstructions; unmasked entries were given to
    # the model, so the ground truth is shown there.
    predicted_values = torch.where(is_masked, output[:, quantity], gt_values).cpu()
    gt_values = gt_values.cpu()
    masked_nodes = torch.nonzero(is_masked).flatten().tolist()

    # One colour scale for all panels, covering ground truth and predictions.
    vmin = min(gt_values.min().item(), predicted_values.min().item())
    vmax = max(gt_values.max().item(), predicted_values.max().item())
    norm = plt.Normalize(vmin=vmin, vmax=vmax)

    graph = _build_grid_graph(data_point.edge_index, data_point.x.shape[0])
    pos = nx.spring_layout(graph, seed=42)
    node_sizes = {"REF": 390, "PV": 600, "PQ": 600}

    fig, axes = plt.subplots(1, 3, figsize=(22, 8))
    panels = (
        (axes[0], gt_values, "Input grid", False),
        (axes[1], gt_values, "Masked grid", True),
        (axes[2], predicted_values, "Reconstructed grid", False),
    )
    for ax, values, title, grey_out_masked in panels:
        _draw_typed_nodes(
            ax, graph, pos, node_labels, values, cmap, vmin, vmax, node_sizes
        )
        if grey_out_masked:
            # Larger grey markers drawn on top hide the masked values.
            nx.draw_networkx_nodes(
                graph,
                pos,
                nodelist=masked_nodes,
                node_color="#D3D3D3",
                node_size=750,
                ax=ax,
            )
        _draw_edges_and_labels(ax, graph, pos, node_labels, edge_width=2)
        ax.set_title(f"{title} {quantity_name}", fontsize=14, fontweight="bold")
        _thicken_spines(ax)

    # Shared colorbar in its own axes to the right of the three panels.
    cbar_ax = fig.add_axes([0.93, 0.1, 0.02, 0.8])
    cbar = fig.colorbar(plt.cm.ScalarMappable(cmap=cmap, norm=norm), cax=cbar_ax)
    cbar.set_label(f"{quantity_name} ({unit})", fontsize=12)
    cbar.ax.tick_params(labelsize=12)

    plt.subplots_adjust(right=0.9)
    plt.show()
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: gridfm-graphkit
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: Grid Foundation Model
|
|
5
|
+
Author-email: Matteo Mazzonelli <Matteo.Mazzonelli1@ibm.com>, Alban Puech <Alban.Puech1@ibm.com>, Jonas Weiss <jwe@zurich.ibm.com>
|
|
6
|
+
Maintainer-email: Matteo Mazzonelli <Matteo.Mazzonelli1@ibm.com>
|
|
7
|
+
License-Expression: Apache-2.0
|
|
8
|
+
Keywords: electric power grid,foundational model,graph neural networks
|
|
9
|
+
Classifier: Development Status :: 2 - Pre-Alpha
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
11
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
12
|
+
Requires-Python: >=3.12.10
|
|
13
|
+
Description-Content-Type: text/markdown
|
|
14
|
+
License-File: LICENSE
|
|
15
|
+
Requires-Dist: mlflow>=3.1.0
|
|
16
|
+
Requires-Dist: nbformat>=5.10.4
|
|
17
|
+
Requires-Dist: networkx>=3.4.2
|
|
18
|
+
Requires-Dist: numpy>=2.2.6
|
|
19
|
+
Requires-Dist: pandas>=2.3.0
|
|
20
|
+
Requires-Dist: plotly>=6.1.2
|
|
21
|
+
Requires-Dist: pyyaml>=6.0.2
|
|
22
|
+
Requires-Dist: torch>=2.7.1
|
|
23
|
+
Requires-Dist: torch-geometric>=2.6.1
|
|
24
|
+
Requires-Dist: torchaudio>=2.7.1
|
|
25
|
+
Requires-Dist: torchvision>=0.22.1
|
|
26
|
+
Provides-Extra: dev
|
|
27
|
+
Requires-Dist: mkdocs-material; extra == "dev"
|
|
28
|
+
Requires-Dist: mkdocstrings[python]; extra == "dev"
|
|
29
|
+
Requires-Dist: pre-commit>=4.2.0; extra == "dev"
|
|
30
|
+
Requires-Dist: bandit>=1.8.5; extra == "dev"
|
|
31
|
+
Requires-Dist: build; extra == "dev"
|
|
32
|
+
Provides-Extra: test
|
|
33
|
+
Requires-Dist: pytest; extra == "test"
|
|
34
|
+
Requires-Dist: pytest-cov; extra == "test"
|
|
35
|
+
Dynamic: license-file
|
|
36
|
+
|
|
37
|
+
# gridfm-graphkit
|
|
38
|
+
[Documentation](https://gridfm.github.io/gridfm-graphkit/)
|
|
39
|
+
|
|
40
|
+
This library is brought to you by the GridFM team to train, fine-tune, and interact with a foundation model for the electric power grid.
|
|
41
|
+
|
|
42
|
+
---
|
|
43
|
+
|
|
44
|
+
<p align="center">
|
|
45
|
+
<img src="docs/figs/pre_training.png" alt="GridFM pre-training overview"/>
|
|
46
|
+
<br/>
|
|
47
|
+
</p>
|
|
48
|
+
|
|
49
|
+
# Installation
|
|
50
|
+
|
|
51
|
+
Create a Python virtual environment and install the package:
|
|
52
|
+
```bash
|
|
53
|
+
git clone git@github.com:gridfm/gridfm-graphkit.git
|
|
54
|
+
cd gridfm-graphkit
|
|
55
|
+
python -m venv venv
|
|
56
|
+
source venv/bin/activate
|
|
57
|
+
pip install .
|
|
58
|
+
```
|
|
59
|
+
|
|
60
|
+
During development, install the package in editable mode together with the `dev` and `test` extras:
|
|
61
|
+
|
|
62
|
+
```bash
|
|
63
|
+
pip install -e ".[dev,test]"
|
|
64
|
+
```
|
|
65
|
+
|
|
66
|
+
# gridfm-graphkit CLI
|
|
67
|
+
|
|
68
|
+
An interface to train, fine-tune, and evaluate GridFM models using configurable YAML files and MLflow tracking.
|
|
69
|
+
|
|
70
|
+
```bash
|
|
71
|
+
gridfm_graphkit <command> [OPTIONS]
|
|
72
|
+
```
|
|
73
|
+
|
|
74
|
+
Available commands:
|
|
75
|
+
|
|
76
|
+
* `train` – Train a new model
|
|
77
|
+
* `predict` – Evaluate an existing model
|
|
78
|
+
* `finetune` – Fine-tune a pre-trained model
|
|
79
|
+
|
|
80
|
+
---
|
|
81
|
+
|
|
82
|
+
## Training Models
|
|
83
|
+
|
|
84
|
+
```bash
|
|
85
|
+
gridfm_graphkit train --config path/to/config.yaml
|
|
86
|
+
```
|
|
87
|
+
|
|
88
|
+
### Arguments
|
|
89
|
+
|
|
90
|
+
| Argument | Type | Description | Default |
|
|
91
|
+
| ---------------- | ------ | ---------------------------------------------------------------- | ------- |
|
|
92
|
+
| `--config` | `str` | **Required for standard training**. Path to base config YAML. | `None` |
|
|
93
|
+
| `--grid` | `str` | **Optional**. Path to grid search YAML. Not supported with `-c`. | `None` |
|
|
94
|
+
| `--exp` | `str` | **Optional**. MLflow experiment name. Defaults to a timestamp. | `None` |
|
|
95
|
+
| `--data_path` | `str` | **Optional**. Root dataset directory. | `data` |
|
|
96
|
+
| `-c` | `flag` | **Optional**. Enable checkpoint mode. | `False` |
|
|
97
|
+
| `--model_exp_id` | `str` | **Required if `-c` is used**. MLflow experiment ID. | `None` |
|
|
98
|
+
| `--model_run_id` | `str` | **Required if `-c` is used**. MLflow run ID. | `None` |
|
|
99
|
+
|
|
100
|
+
### Examples
|
|
101
|
+
|
|
102
|
+
**Standard Training:**
|
|
103
|
+
|
|
104
|
+
```bash
|
|
105
|
+
gridfm_graphkit train --config config/train.yaml --exp "run1"
|
|
106
|
+
```
|
|
107
|
+
|
|
108
|
+
**Grid Search Training:**
|
|
109
|
+
|
|
110
|
+
```bash
|
|
111
|
+
gridfm_graphkit train --config config/train.yaml --grid config/grid.yaml
|
|
112
|
+
```
|
|
113
|
+
|
|
114
|
+
**Training from Checkpoint:**
|
|
115
|
+
|
|
116
|
+
```bash
|
|
117
|
+
gridfm_graphkit train -c --model_exp_id 123 --model_run_id abc
|
|
118
|
+
```
|
|
119
|
+
|
|
120
|
+
---
|
|
121
|
+
|
|
122
|
+
## Evaluating Models
|
|
123
|
+
|
|
124
|
+
```bash
|
|
125
|
+
gridfm_graphkit predict --model_path model.pth --config config/eval.yaml --eval_name run_eval
|
|
126
|
+
```
|
|
127
|
+
|
|
128
|
+
### Arguments
|
|
129
|
+
|
|
130
|
+
| Argument | Type | Description | Default |
|
|
131
|
+
| ---------------- | ----- | ----------------------------------------------------------------- | ------------ |
|
|
132
|
+
| `--model_path` | `str` | **Optional**. Path to a model file. | `None` |
|
|
133
|
+
| `--model_exp_id` | `str` | **Required if `--model_path` is not used**. MLflow experiment ID. | `None` |
|
|
134
|
+
| `--model_run_id` | `str` | **Required if `--model_path` is not used**. MLflow run ID. | `None` |
|
|
135
|
+
| `--model_name` | `str` | **Optional**. Filename inside MLflow artifacts. | `best_model` |
|
|
136
|
+
| `--config` | `str` | **Required**. Path to evaluation config. | `None` |
|
|
137
|
+
| `--eval_name` | `str` | **Required**. Name of the evaluation run in MLflow. | `None` |
|
|
138
|
+
| `--data_path` | `str` | **Optional**. Path to dataset directory. | `data` |
|
|
139
|
+
|
|
140
|
+
### Examples
|
|
141
|
+
|
|
142
|
+
**Evaluate a Logged MLflow Model:**
|
|
143
|
+
|
|
144
|
+
```bash
|
|
145
|
+
gridfm_graphkit predict --config config/eval.yaml --eval_name run_eval --model_exp_id 1 --model_run_id abc
|
|
146
|
+
```
|
|
147
|
+
|
|
148
|
+
---
|
|
149
|
+
|
|
150
|
+
## Fine-Tuning Models
|
|
151
|
+
|
|
152
|
+
```bash
|
|
153
|
+
gridfm_graphkit finetune --config path/to/config.yaml --model_path path/to/model.pth
|
|
154
|
+
```
|
|
155
|
+
|
|
156
|
+
### Arguments
|
|
157
|
+
|
|
158
|
+
| Argument | Type | Description | Default |
|
|
159
|
+
| -------------- | ----- | ----------------------------------------------- | ------- |
|
|
160
|
+
| `--config` | `str` | **Required**. Fine-tuning configuration file. | `None` |
|
|
161
|
+
| `--model_path` | `str` | **Required**. Path to a pre-trained model file. | `None` |
|
|
162
|
+
| `--exp` | `str` | **Optional**. MLflow experiment name. | `None` |
|
|
163
|
+
| `--data_path` | `str` | **Optional**. Root dataset directory. | `data` |
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
gridfm_graphkit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
+
gridfm_graphkit/__main__.py,sha256=nPp_l6YbP62qE907xIrJebLX2DodieZjH79kULAxKgA,2453
|
|
3
|
+
gridfm_graphkit/cli.py,sha256=4j9RN8__EVM2p125k8X7kB507lG_RTUPV_vzPF40Ens,17326
|
|
4
|
+
gridfm_graphkit/datasets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
5
|
+
gridfm_graphkit/datasets/data_normalization.py,sha256=0nlh5wwFATqqdjpTFADi0De4b4HgkkqRHjyd5dvmRlk,7186
|
|
6
|
+
gridfm_graphkit/datasets/globals.py,sha256=3f0Pcap-_XDJl-uflTpxjnhr1-1p84KzwJYvSiessEo,254
|
|
7
|
+
gridfm_graphkit/datasets/powergrid.py,sha256=nShTY8r7q9WZxLZQjkti0pm9qQheJ9NdD4lScrVeoPY,7391
|
|
8
|
+
gridfm_graphkit/datasets/transforms.py,sha256=p8cLhnl2Ey1RVtwbo6KLzYgQjf1W_RYMBkwfiPpbgSs,7397
|
|
9
|
+
gridfm_graphkit/datasets/utils.py,sha256=ZaXf_AM8dlqURocLBrtD7dn2e5msyQIOHXjF5Da51TY,2395
|
|
10
|
+
gridfm_graphkit/io/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
|
+
gridfm_graphkit/io/param_handler.py,sha256=s2NIA958lG-5tinYiStDrL4Ua5n9-3B3jeyG8eLjQRM,8998
|
|
12
|
+
gridfm_graphkit/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
13
|
+
gridfm_graphkit/models/gps_transformer.py,sha256=XyCt-fisnISYauhuknM5Zlx9N-gpby71NpzWEzehjEI,5108
|
|
14
|
+
gridfm_graphkit/models/graphTransformer.py,sha256=aFx_ZXzBvAdXmgdIi7mwT6yw_2h_3t4j22V_xdNcGzk,3364
|
|
15
|
+
gridfm_graphkit/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
16
|
+
gridfm_graphkit/training/callbacks.py,sha256=5XSEcRRg-Hhuv5vQ8qgvGyC1h_ccLtnoS-fekQzehRw,1550
|
|
17
|
+
gridfm_graphkit/training/plugins.py,sha256=4w0UHlrafYq2WlFIeXbWludRQVUBErqw4dnslKMt8V8,7485
|
|
18
|
+
gridfm_graphkit/training/trainer.py,sha256=q1Y4NOkvMTQT3m75WSU7lTLa3MLiHrA_2AcaqrQ0ce0,5593
|
|
19
|
+
gridfm_graphkit/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
20
|
+
gridfm_graphkit/utils/loss.py,sha256=OyEu5uheMgKuNE7YMoYI-dmquNHopdH0aiF8vSbS4kE,6519
|
|
21
|
+
gridfm_graphkit/utils/visualization.py,sha256=FXAU5nmo6bY_Pwr8X_p42zQSwr-bknJVDABPTeZSnOs,9491
|
|
22
|
+
gridfm_graphkit-0.0.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
23
|
+
gridfm_graphkit-0.0.1.dist-info/METADATA,sha256=a_cgS-Nm-lz2ozouUBRd7dNYV2EByQQk--uxGPzYPhY,5772
|
|
24
|
+
gridfm_graphkit-0.0.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
25
|
+
gridfm_graphkit-0.0.1.dist-info/entry_points.txt,sha256=CVYrtG2_4yKucL63S5klXfmCvfP7l7MBjvfyUuryXcE,66
|
|
26
|
+
gridfm_graphkit-0.0.1.dist-info/top_level.txt,sha256=p3OXCMb-zDtflOl9N4s2Lvp-7hJlmRhf32wE0xbFins,16
|
|
27
|
+
gridfm_graphkit-0.0.1.dist-info/RECORD,,
|