gridfm-graphkit 0.0.1__tar.gz → 0.0.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. {gridfm_graphkit-0.0.1 → gridfm_graphkit-0.0.2}/PKG-INFO +16 -8
  2. {gridfm_graphkit-0.0.1 → gridfm_graphkit-0.0.2}/README.md +13 -5
  3. gridfm_graphkit-0.0.2/gridfm_graphkit/evaluation/node_level.py +334 -0
  4. gridfm_graphkit-0.0.2/gridfm_graphkit/utils/__init__.py +0 -0
  5. {gridfm_graphkit-0.0.1 → gridfm_graphkit-0.0.2}/gridfm_graphkit.egg-info/PKG-INFO +16 -8
  6. {gridfm_graphkit-0.0.1 → gridfm_graphkit-0.0.2}/gridfm_graphkit.egg-info/SOURCES.txt +2 -0
  7. {gridfm_graphkit-0.0.1 → gridfm_graphkit-0.0.2}/pyproject.toml +8 -4
  8. {gridfm_graphkit-0.0.1 → gridfm_graphkit-0.0.2}/LICENSE +0 -0
  9. {gridfm_graphkit-0.0.1 → gridfm_graphkit-0.0.2}/gridfm_graphkit/__init__.py +0 -0
  10. {gridfm_graphkit-0.0.1 → gridfm_graphkit-0.0.2}/gridfm_graphkit/__main__.py +0 -0
  11. {gridfm_graphkit-0.0.1 → gridfm_graphkit-0.0.2}/gridfm_graphkit/cli.py +0 -0
  12. {gridfm_graphkit-0.0.1 → gridfm_graphkit-0.0.2}/gridfm_graphkit/datasets/__init__.py +0 -0
  13. {gridfm_graphkit-0.0.1 → gridfm_graphkit-0.0.2}/gridfm_graphkit/datasets/data_normalization.py +0 -0
  14. {gridfm_graphkit-0.0.1 → gridfm_graphkit-0.0.2}/gridfm_graphkit/datasets/globals.py +0 -0
  15. {gridfm_graphkit-0.0.1 → gridfm_graphkit-0.0.2}/gridfm_graphkit/datasets/powergrid.py +0 -0
  16. {gridfm_graphkit-0.0.1 → gridfm_graphkit-0.0.2}/gridfm_graphkit/datasets/transforms.py +0 -0
  17. {gridfm_graphkit-0.0.1 → gridfm_graphkit-0.0.2}/gridfm_graphkit/datasets/utils.py +0 -0
  18. {gridfm_graphkit-0.0.1/gridfm_graphkit/io → gridfm_graphkit-0.0.2/gridfm_graphkit/evaluation}/__init__.py +0 -0
  19. {gridfm_graphkit-0.0.1/gridfm_graphkit/models → gridfm_graphkit-0.0.2/gridfm_graphkit/io}/__init__.py +0 -0
  20. {gridfm_graphkit-0.0.1 → gridfm_graphkit-0.0.2}/gridfm_graphkit/io/param_handler.py +0 -0
  21. {gridfm_graphkit-0.0.1/gridfm_graphkit/training → gridfm_graphkit-0.0.2/gridfm_graphkit/models}/__init__.py +0 -0
  22. {gridfm_graphkit-0.0.1 → gridfm_graphkit-0.0.2}/gridfm_graphkit/models/gps_transformer.py +0 -0
  23. {gridfm_graphkit-0.0.1 → gridfm_graphkit-0.0.2}/gridfm_graphkit/models/graphTransformer.py +0 -0
  24. {gridfm_graphkit-0.0.1/gridfm_graphkit/utils → gridfm_graphkit-0.0.2/gridfm_graphkit/training}/__init__.py +0 -0
  25. {gridfm_graphkit-0.0.1 → gridfm_graphkit-0.0.2}/gridfm_graphkit/training/callbacks.py +0 -0
  26. {gridfm_graphkit-0.0.1 → gridfm_graphkit-0.0.2}/gridfm_graphkit/training/plugins.py +0 -0
  27. {gridfm_graphkit-0.0.1 → gridfm_graphkit-0.0.2}/gridfm_graphkit/training/trainer.py +0 -0
  28. {gridfm_graphkit-0.0.1 → gridfm_graphkit-0.0.2}/gridfm_graphkit/utils/loss.py +0 -0
  29. {gridfm_graphkit-0.0.1 → gridfm_graphkit-0.0.2}/gridfm_graphkit/utils/visualization.py +0 -0
  30. {gridfm_graphkit-0.0.1 → gridfm_graphkit-0.0.2}/gridfm_graphkit.egg-info/dependency_links.txt +0 -0
  31. {gridfm_graphkit-0.0.1 → gridfm_graphkit-0.0.2}/gridfm_graphkit.egg-info/entry_points.txt +0 -0
  32. {gridfm_graphkit-0.0.1 → gridfm_graphkit-0.0.2}/gridfm_graphkit.egg-info/requires.txt +0 -0
  33. {gridfm_graphkit-0.0.1 → gridfm_graphkit-0.0.2}/gridfm_graphkit.egg-info/top_level.txt +0 -0
  34. {gridfm_graphkit-0.0.1 → gridfm_graphkit-0.0.2}/setup.cfg +0 -0
  35. {gridfm_graphkit-0.0.1 → gridfm_graphkit-0.0.2}/tests/test_training.py +0 -0
  36. {gridfm_graphkit-0.0.1 → gridfm_graphkit-0.0.2}/tests/test_yaml_configs.py +0 -0
@@ -1,9 +1,9 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gridfm-graphkit
3
- Version: 0.0.1
3
+ Version: 0.0.2
4
4
  Summary: Grid Foundation Model
5
- Author-email: Matteo Mazzonelli <Matteo.Mazzonelli1@ibm.com>, Alban Puech <Alban.Puech1@ibm.com>, Jonas Weiss <jwe@zurich.ibm.com>
6
- Maintainer-email: Matteo Mazzonelli <Matteo.Mazzonelli1@ibm.com>
5
+ Author-email: Matteo Mazzonelli <matteo.mazzonelli1@ibm.com>, Alban Puech <apuech@seas.harvard.edu>, Tamara Govindasamy <tamara.govindasamy@ibm.com>, Mangaliso Mngomezulu <mngomezulum@ibm.com>, Etienne Vos <etienne.vos@ibm.com>, Celia Cintas <celia.cintas@ibm.com>, Jonas Weiss <jwe@zurich.ibm.com>
6
+ Maintainer-email: Matteo Mazzonelli <matteo.mazzonelli1@ibm.com>
7
7
  License-Expression: Apache-2.0
8
8
  Keywords: electric power grid,foundational model,graph neural networks
9
9
  Classifier: Development Status :: 2 - Pre-Alpha
@@ -35,34 +35,42 @@ Requires-Dist: pytest-cov; extra == "test"
35
35
  Dynamic: license-file
36
36
 
37
37
  # gridfm-graphkit
38
- [![Docs](https://img.shields.io/badge/docs-live-brightgreen)](https://gridfm.github.io/gridfm-graphkit/)
38
+ [![Docs](https://img.shields.io/badge/docs-available-brightgreen)](https://gridfm.github.io/gridfm-graphkit/)
39
39
 
40
40
  This library is brought to you by the GridFM team to train, finetune and interact with a foundation model for the electric power grid.
41
41
 
42
42
  ---
43
43
 
44
44
  <p align="center">
45
- <img src="docs/figs/pre_training.png" alt="GridFM logo"/>
45
+ <img src="https://raw.githubusercontent.com/gridfm/gridfm-graphkit/refs/heads/main/docs/figs/pre_training.png" alt="GridFM logo"/>
46
46
  <br/>
47
47
  </p>
48
48
 
49
49
  # Installation
50
50
 
51
- Create a python virtual environment and install the requirements
51
+ You can install `gridfm-graphkit` directly from PyPI:
52
+
53
+ ```bash
54
+ pip install gridfm-graphkit
55
+ ```
56
+
57
+ To contribute or develop locally, clone the repository and install in editable mode:
58
+
52
59
  ```bash
53
60
  git clone git@github.com:gridfm/gridfm-graphkit.git
54
61
  cd gridfm-graphkit
55
62
  python -m venv venv
56
63
  source venv/bin/activate
57
- pip install .
64
+ pip install -e .
58
65
  ```
59
66
 
60
- Install the package in editable mode during development phase:
67
+ For documentation generation and unit testing, install with the optional `dev` and `test` extras:
61
68
 
62
69
  ```bash
63
70
  pip install -e .[dev,test]
64
71
  ```
65
72
 
73
+
66
74
  # gridfm-graphkit CLI
67
75
 
68
76
  An interface to train, fine-tune, and evaluate GridFM models using configurable YAML files and MLflow tracking.
@@ -1,32 +1,40 @@
1
1
  # gridfm-graphkit
2
- [![Docs](https://img.shields.io/badge/docs-live-brightgreen)](https://gridfm.github.io/gridfm-graphkit/)
2
+ [![Docs](https://img.shields.io/badge/docs-available-brightgreen)](https://gridfm.github.io/gridfm-graphkit/)
3
3
 
4
4
  This library is brought to you by the GridFM team to train, finetune and interact with a foundation model for the electric power grid.
5
5
 
6
6
  ---
7
7
 
8
8
  <p align="center">
9
- <img src="docs/figs/pre_training.png" alt="GridFM logo"/>
9
+ <img src="https://raw.githubusercontent.com/gridfm/gridfm-graphkit/refs/heads/main/docs/figs/pre_training.png" alt="GridFM logo"/>
10
10
  <br/>
11
11
  </p>
12
12
 
13
13
  # Installation
14
14
 
15
- Create a python virtual environment and install the requirements
15
+ You can install `gridfm-graphkit` directly from PyPI:
16
+
17
+ ```bash
18
+ pip install gridfm-graphkit
19
+ ```
20
+
21
+ To contribute or develop locally, clone the repository and install in editable mode:
22
+
16
23
  ```bash
17
24
  git clone git@github.com:gridfm/gridfm-graphkit.git
18
25
  cd gridfm-graphkit
19
26
  python -m venv venv
20
27
  source venv/bin/activate
21
- pip install .
28
+ pip install -e .
22
29
  ```
23
30
 
24
- Install the package in editable mode during development phase:
31
+ For documentation generation and unit testing, install with the optional `dev` and `test` extras:
25
32
 
26
33
  ```bash
27
34
  pip install -e .[dev,test]
28
35
  ```
29
36
 
37
+
30
38
  # gridfm-graphkit CLI
31
39
 
32
40
  An interface to train, fine-tune, and evaluate GridFM models using configurable YAML files and MLflow tracking.
@@ -0,0 +1,334 @@
1
+ from gridfm_graphkit.datasets.globals import BUS_TYPES, FEATURES_IDX, PQ, PV, REF
2
+ from gridfm_graphkit.datasets.data_normalization import BaseMVANormalizer
3
+ from gridfm_graphkit.datasets.transforms import AddRandomMask, AddPFMask, AddOPFMask
4
+ from gridfm_graphkit.utils.loss import PBELoss
5
+
6
+ import torch
7
+ import numpy as np
8
+ import pandas as pd
9
+ from typing import List, Tuple
10
+ from torch.utils.data import DataLoader
11
+ import plotly.graph_objects as go
12
+ from torch_geometric.data import Dataset
13
+
14
+
15
def get_dist_plot(
    data: np.ndarray,
    data_type: str,
    bus_types: List[str],
    n_buses: int,
) -> List[go.Figure]:
    """
    Generates one distribution plot per feature, with one box per bus.

    The rows of ``data`` are read with a stride of ``n_buses``: row ``i`` is
    attributed to bus ``i % n_buses``. This only yields per-bus distributions
    if the rows are stacked graph after graph and every graph contributes
    exactly ``n_buses`` rows in the same bus order.

    Args:
        data (np.ndarray): The input data matrix, e.g. residuals or model outputs, of shape (n_buses x len(test_dataset), n_features)
        data_type (str): The type of data being plotted (e.g., 'residuals', 'model outputs').
        bus_types (List[str]): List of bus types for each bus in the graphs
        n_buses (int): The total number of buses in the grid.

    Returns:
        List[go.Figure]: List of Plotly figures, one per entry of FEATURES_IDX, each holding box plots of the distribution of that feature for each of the buses
    """
    figs = []

    for feature, feature_idx in FEATURES_IDX.items():
        fig = go.Figure()

        for bus_idx in range(n_buses):
            # Every n_buses-th row starting at bus_idx belongs to this bus
            # (see the layout assumption in the docstring).
            bus_values = data[bus_idx::n_buses, feature_idx]
            fig.add_trace(
                go.Box(
                    y=bus_values,
                    name=f"Bus {bus_idx} ({bus_types[bus_idx]})",
                ),
            )

        fig.update_layout(
            title=f"{feature} {data_type} distribution",
            xaxis_title="Bus Number",
            yaxis_title=f"{data_type}",
            showlegend=True,
        )
        figs.append(fig)

    return figs
59
+
60
+
61
def training_stats_to_dataframe(
    rmse_PQ: np.ndarray,
    rmse_PV: np.ndarray,
    rmse_REF: np.ndarray,
    mae_PQ: np.ndarray,
    mae_PV: np.ndarray,
    mae_REF: np.ndarray,
    overall_RMSE: np.ndarray,
    overall_MAE: np.ndarray,
    overall_active_loss: float,
    overall_reactive_loss: float,
) -> pd.DataFrame:
    """
    Converts training statistics into a pandas DataFrame.

    The table has one row per metric and one column per feature. Each
    per-feature argument is indexed at positions 0..5, in the column order
    Pd, Qd, Pg, Qg, Vm, Va. The two average power residuals are reported in
    the first feature column ("Pd (MW)") only; the remaining feature columns
    hold the placeholder string " " on those two rows.

    Args:
        rmse_PQ (np.ndarray): RMSE for each feature at PQ nodes.
        rmse_PV (np.ndarray): RMSE for each feature at PV nodes.
        rmse_REF (np.ndarray): RMSE for each feature at REF nodes.
        mae_PQ (np.ndarray): MAE for each feature at PQ nodes.
        mae_PV (np.ndarray): MAE for each feature at PV nodes.
        mae_REF (np.ndarray): MAE for each feature at REF nodes.
        overall_RMSE (np.ndarray): RMSE for each feature over all nodes.
        overall_MAE (np.ndarray): MAE for each feature over all nodes.
        overall_active_loss (float): Mean active power loss across nodes
        overall_reactive_loss (float): Mean reactive power loss across nodes

    Returns:
        pd.DataFrame: DataFrame containing aggregated statistics.
    """
    metric_names = [
        "RMSE-PQ",
        "RMSE-PV",
        "RMSE-REF",
        "MAE-PQ",
        "MAE-PV",
        "MAE-REF",
        "Overall RMSE",
        "Overall MAE",
        "Avg. active res. (MW)",
        "Avg. reactive res. (MVar)",
    ]
    # Column order fixes which position of each statistics vector is read.
    feature_columns = [
        "Pd (MW)",
        "Qd (MVar)",
        "Pg (MW)",
        "Qg (MVar)",
        "Vm (p.u.)",
        "Va (degree)",
    ]
    # Same order as the first eight entries of metric_names.
    per_feature_stats = (
        rmse_PQ,
        rmse_PV,
        rmse_REF,
        mae_PQ,
        mae_PV,
        mae_REF,
        overall_RMSE,
        overall_MAE,
    )

    data = {"Metric": metric_names}
    for position, column in enumerate(feature_columns):
        values = [stat[position] for stat in per_feature_stats]
        if position == 0:
            # The scalar power residuals are shown once, in the first column.
            values.extend([overall_active_loss, overall_reactive_loss])
        else:
            values.extend([" ", " "])
        data[column] = values

    return pd.DataFrame(data)
177
+
178
+
179
def eval_node_level_task(
    dataset: Dataset,
    model: torch.nn.Module,
    task: str,
    test_loader: DataLoader,
    mask_dim: int,
    mask_ratio: float,
    node_normalizer: object,
    device: torch.device,
    plot_dist: bool = True,
) -> Tuple[pd.DataFrame, List[go.Figure]]:
    """
    Evaluates the model and computes per-feature statistics.

    The dataset transform is switched to the mask matching ``task`` for the
    duration of the evaluation and is reset afterwards, also when the
    evaluation fails part-way. The model is put in eval mode and left there.

    Args:
        dataset (Dataset): Dataset whose transform is switched to the task mask;
            must provide ``change_transform`` and ``reset_transform``.
        model (torch.nn.Module): The trained model. Must expose ``mask_value``.
        task: task to evaluate the model on: "PF", "OPF" or "Reconstruction".
        test_loader (DataLoader): DataLoader for test data.
        mask_dim (int): number of masked features; also the column offset at
            which the bus-type columns of ``batch.x`` are read.
        mask_ratio (float): mask ratio passed to AddRandomMask (only used for
            the "Reconstruction" task).
        node_normalizer (object): Normalizer for input/output features.
        device (torch.device): Device to run the evaluation on.
        plot_dist (bool): Whether to generate distribution plots.

    Returns:
        Tuple[pd.DataFrame, List[go.Figure]]: DataFrame with evaluation metrics
        and the list of plotly figures (empty when ``plot_dist`` is False).

    Raises:
        ValueError: If ``task`` is unknown or ``test_loader`` yields no batches.
    """
    model.eval()

    # Initialize lists to collect outputs and targets
    all_outputs = []
    all_targets = []
    all_mask_PQ = []
    all_mask_PV = []
    all_mask_REF = []
    all_active_loss = []
    all_reactive_loss = []

    loss_PBE = PBELoss()

    # Mask input features
    if task == "PF":
        dataset.change_transform(AddPFMask())
    elif task == "OPF":
        dataset.change_transform(AddOPFMask())
    elif task == "Reconstruction":
        dataset.change_transform(
            AddRandomMask(mask_dim=mask_dim, mask_ratio=mask_ratio),
        )
    else:
        raise ValueError(f"Unknown task: {task}")

    try:
        with torch.no_grad():
            for batch in test_loader:
                batch = batch.to(device)

                # Bus-type masks are read before the masked features are overwritten.
                mask_PQ = batch.x[:, PQ] == 1
                mask_PV = batch.x[:, PV] == 1
                mask_REF = batch.x[:, REF] == 1

                mask_value_expanded = model.mask_value.expand(batch.x.shape[0], -1)
                batch.x[:, : batch.mask.shape[1]][batch.mask] = mask_value_expanded[
                    batch.mask
                ]

                # Forward pass
                output = model(
                    batch.x,
                    batch.pe,
                    batch.edge_index,
                    batch.edge_attr,
                    batch.batch,
                )

                if isinstance(node_normalizer, BaseMVANormalizer):
                    loss_PBE_dict = loss_PBE(
                        output,
                        batch.y,
                        batch.edge_index,
                        batch.edge_attr,
                        batch.mask,
                    )
                    # float() makes the later np.mean independent of whether the
                    # loss entries are Python numbers or (possibly on-device)
                    # scalar tensors. NOTE(review): assumes scalar entries - confirm
                    # against PBELoss.
                    all_active_loss.append(
                        float(
                            loss_PBE_dict["Active Power Loss in p.u."]
                            * node_normalizer.baseMVA,
                        ),
                    )
                    all_reactive_loss.append(
                        float(
                            loss_PBE_dict["Reactive Power Loss in p.u."]
                            * node_normalizer.baseMVA,
                        ),
                    )
                else:
                    # Sentinel: power balance residuals are only computed for
                    # BaseMVANormalizer.
                    all_active_loss.append(-1.0)
                    all_reactive_loss.append(-1.0)

                # Denormalize
                output_denorm = node_normalizer.inverse_transform(output)
                target_denorm = node_normalizer.inverse_transform(batch.y)

                # Collect outputs, targets, and masks
                all_outputs.append(output_denorm)
                all_targets.append(target_denorm)
                all_mask_PQ.append(mask_PQ)
                all_mask_PV.append(mask_PV)
                all_mask_REF.append(mask_REF)

            if not all_outputs:
                raise ValueError("test_loader yielded no batches; nothing to evaluate")

            # Grid layout is taken from the first graph of the last batch.
            n_buses = int((batch.batch == 0).sum())  # Number of buses in graph
            bus_types = [
                BUS_TYPES[np.argmax(row[mask_dim:])] for row in batch.x[:n_buses].cpu()
            ]  # Ugly hack to get bus types from input features

        # Concatenate all outputs, targets, and masks, and move them to numpy so
        # that all statistics below operate on plain arrays.
        outputs_np = torch.cat(all_outputs, dim=0).cpu().numpy()
        targets_np = torch.cat(all_targets, dim=0).cpu().numpy()
        mask_PQ_np = torch.cat(all_mask_PQ, dim=0).cpu().numpy()
        mask_PV_np = torch.cat(all_mask_PV, dim=0).cpu().numpy()
        mask_REF_np = torch.cat(all_mask_REF, dim=0).cpu().numpy()

        # Compute per-feature RMSE and MAE after collecting all batches
        residuals = outputs_np - targets_np
        squared_residuals = residuals**2
        absolute_residuals = np.abs(residuals)

        rmse_PQ = np.sqrt(np.mean(squared_residuals[mask_PQ_np], axis=0))
        rmse_PV = np.sqrt(np.mean(squared_residuals[mask_PV_np], axis=0))
        rmse_REF = np.sqrt(np.mean(squared_residuals[mask_REF_np], axis=0))

        mae_PQ = np.mean(absolute_residuals[mask_PQ_np], axis=0)
        mae_PV = np.mean(absolute_residuals[mask_PV_np], axis=0)
        mae_REF = np.mean(absolute_residuals[mask_REF_np], axis=0)

        overall_RMSE = np.sqrt(np.mean(squared_residuals, axis=0))
        overall_MAE = np.mean(absolute_residuals, axis=0)

        overall_active_loss = np.mean(all_active_loss)
        overall_reactive_loss = np.mean(all_reactive_loss)

        figs = []

        if plot_dist:
            figs.extend(get_dist_plot(residuals, "residuals", bus_types, n_buses))
            figs.extend(get_dist_plot(outputs_np, "model outputs", bus_types, n_buses))

        # Convert statistics to a DataFrame
        df = training_stats_to_dataframe(
            rmse_PQ,
            rmse_PV,
            rmse_REF,
            mae_PQ,
            mae_PV,
            mae_REF,
            overall_RMSE,
            overall_MAE,
            overall_active_loss,
            overall_reactive_loss,
        )
        return df, figs
    finally:
        # Always restore the dataset's original transform, even on failure.
        dataset.reset_transform()
@@ -1,9 +1,9 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gridfm-graphkit
3
- Version: 0.0.1
3
+ Version: 0.0.2
4
4
  Summary: Grid Foundation Model
5
- Author-email: Matteo Mazzonelli <Matteo.Mazzonelli1@ibm.com>, Alban Puech <Alban.Puech1@ibm.com>, Jonas Weiss <jwe@zurich.ibm.com>
6
- Maintainer-email: Matteo Mazzonelli <Matteo.Mazzonelli1@ibm.com>
5
+ Author-email: Matteo Mazzonelli <matteo.mazzonelli1@ibm.com>, Alban Puech <apuech@seas.harvard.edu>, Tamara Govindasamy <tamara.govindasamy@ibm.com>, Mangaliso Mngomezulu <mngomezulum@ibm.com>, Etienne Vos <etienne.vos@ibm.com>, Celia Cintas <celia.cintas@ibm.com>, Jonas Weiss <jwe@zurich.ibm.com>
6
+ Maintainer-email: Matteo Mazzonelli <matteo.mazzonelli1@ibm.com>
7
7
  License-Expression: Apache-2.0
8
8
  Keywords: electric power grid,foundational model,graph neural networks
9
9
  Classifier: Development Status :: 2 - Pre-Alpha
@@ -35,34 +35,42 @@ Requires-Dist: pytest-cov; extra == "test"
35
35
  Dynamic: license-file
36
36
 
37
37
  # gridfm-graphkit
38
- [![Docs](https://img.shields.io/badge/docs-live-brightgreen)](https://gridfm.github.io/gridfm-graphkit/)
38
+ [![Docs](https://img.shields.io/badge/docs-available-brightgreen)](https://gridfm.github.io/gridfm-graphkit/)
39
39
 
40
40
  This library is brought to you by the GridFM team to train, finetune and interact with a foundation model for the electric power grid.
41
41
 
42
42
  ---
43
43
 
44
44
  <p align="center">
45
- <img src="docs/figs/pre_training.png" alt="GridFM logo"/>
45
+ <img src="https://raw.githubusercontent.com/gridfm/gridfm-graphkit/refs/heads/main/docs/figs/pre_training.png" alt="GridFM logo"/>
46
46
  <br/>
47
47
  </p>
48
48
 
49
49
  # Installation
50
50
 
51
- Create a python virtual environment and install the requirements
51
+ You can install `gridfm-graphkit` directly from PyPI:
52
+
53
+ ```bash
54
+ pip install gridfm-graphkit
55
+ ```
56
+
57
+ To contribute or develop locally, clone the repository and install in editable mode:
58
+
52
59
  ```bash
53
60
  git clone git@github.com:gridfm/gridfm-graphkit.git
54
61
  cd gridfm-graphkit
55
62
  python -m venv venv
56
63
  source venv/bin/activate
57
- pip install .
64
+ pip install -e .
58
65
  ```
59
66
 
60
- Install the package in editable mode during development phase:
67
+ For documentation generation and unit testing, install with the optional `dev` and `test` extras:
61
68
 
62
69
  ```bash
63
70
  pip install -e .[dev,test]
64
71
  ```
65
72
 
73
+
66
74
  # gridfm-graphkit CLI
67
75
 
68
76
  An interface to train, fine-tune, and evaluate GridFM models using configurable YAML files and MLflow tracking.
@@ -16,6 +16,8 @@ gridfm_graphkit/datasets/globals.py
16
16
  gridfm_graphkit/datasets/powergrid.py
17
17
  gridfm_graphkit/datasets/transforms.py
18
18
  gridfm_graphkit/datasets/utils.py
19
+ gridfm_graphkit/evaluation/__init__.py
20
+ gridfm_graphkit/evaluation/node_level.py
19
21
  gridfm_graphkit/io/__init__.py
20
22
  gridfm_graphkit/io/param_handler.py
21
23
  gridfm_graphkit/models/__init__.py
@@ -9,19 +9,23 @@ namespaces = false
9
9
  [project]
10
10
  name = "gridfm-graphkit"
11
11
  description = "Grid Foundation Model"
12
- version = "0.0.1"
12
+ version = "0.0.2"
13
13
  readme = "README.md"
14
14
  license = "Apache-2.0"
15
15
  requires-python = ">=3.12.10"
16
16
 
17
17
  authors = [
18
- {name = "Matteo Mazzonelli", email = "Matteo.Mazzonelli1@ibm.com"},
19
- {name = "Alban Puech", email = "Alban.Puech1@ibm.com"},
18
+ {name = "Matteo Mazzonelli", email = "matteo.mazzonelli1@ibm.com"},
19
+ {name = "Alban Puech", email = "apuech@seas.harvard.edu"},
20
+ {name = "Tamara Govindasamy", email= "tamara.govindasamy@ibm.com"},
21
+ {name = "Mangaliso Mngomezulu", email= "mngomezulum@ibm.com"},
22
+ {name = "Etienne Vos", email= "etienne.vos@ibm.com"},
23
+ {name = "Celia Cintas", email= "celia.cintas@ibm.com"},
20
24
  {name = "Jonas Weiss", email= "jwe@zurich.ibm.com"},
21
25
  ]
22
26
 
23
27
  maintainers = [
24
- {name = "Matteo Mazzonelli", email = "Matteo.Mazzonelli1@ibm.com"},
28
+ {name = "Matteo Mazzonelli", email = "matteo.mazzonelli1@ibm.com"},
25
29
  ]
26
30
 
27
31
  keywords = ["electric power grid", "foundational model", "graph neural networks"]
File without changes