gridfm-graphkit 0.0.2a0__tar.gz → 0.0.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55) hide show
  1. gridfm_graphkit-0.0.4/PKG-INFO +180 -0
  2. gridfm_graphkit-0.0.4/README.md +141 -0
  3. gridfm_graphkit-0.0.4/gridfm_graphkit/__init__.py +6 -0
  4. gridfm_graphkit-0.0.4/gridfm_graphkit/__main__.py +58 -0
  5. gridfm_graphkit-0.0.4/gridfm_graphkit/cli.py +134 -0
  6. gridfm_graphkit-0.0.4/gridfm_graphkit/datasets/__init__.py +23 -0
  7. gridfm_graphkit-0.0.2a0/gridfm_graphkit/datasets/data_normalization.py → gridfm_graphkit-0.0.4/gridfm_graphkit/datasets/normalizers.py +58 -6
  8. gridfm_graphkit-0.0.4/gridfm_graphkit/datasets/powergrid_datamodule.py +207 -0
  9. gridfm_graphkit-0.0.2a0/gridfm_graphkit/datasets/powergrid.py → gridfm_graphkit-0.0.4/gridfm_graphkit/datasets/powergrid_dataset.py +57 -22
  10. {gridfm_graphkit-0.0.2a0 → gridfm_graphkit-0.0.4}/gridfm_graphkit/datasets/transforms.py +17 -3
  11. {gridfm_graphkit-0.0.2a0 → gridfm_graphkit-0.0.4}/gridfm_graphkit/datasets/utils.py +0 -16
  12. gridfm_graphkit-0.0.4/gridfm_graphkit/io/param_handler.py +138 -0
  13. gridfm_graphkit-0.0.4/gridfm_graphkit/io/registries.py +42 -0
  14. gridfm_graphkit-0.0.4/gridfm_graphkit/models/__init__.py +4 -0
  15. gridfm_graphkit-0.0.2a0/gridfm_graphkit/models/graphTransformer.py → gridfm_graphkit-0.0.4/gridfm_graphkit/models/gnn_transformer.py +34 -34
  16. {gridfm_graphkit-0.0.2a0 → gridfm_graphkit-0.0.4}/gridfm_graphkit/models/gps_transformer.py +40 -43
  17. gridfm_graphkit-0.0.4/gridfm_graphkit/tasks/feature_reconstruction_task.py +366 -0
  18. gridfm_graphkit-0.0.4/gridfm_graphkit/training/callbacks.py +49 -0
  19. {gridfm_graphkit-0.0.2a0/gridfm_graphkit/utils → gridfm_graphkit-0.0.4/gridfm_graphkit/training}/loss.py +9 -9
  20. gridfm_graphkit-0.0.4/gridfm_graphkit/utils/visualization.py +99 -0
  21. gridfm_graphkit-0.0.4/gridfm_graphkit.egg-info/PKG-INFO +180 -0
  22. {gridfm_graphkit-0.0.2a0 → gridfm_graphkit-0.0.4}/gridfm_graphkit.egg-info/SOURCES.txt +13 -9
  23. {gridfm_graphkit-0.0.2a0 → gridfm_graphkit-0.0.4}/gridfm_graphkit.egg-info/requires.txt +1 -0
  24. {gridfm_graphkit-0.0.2a0 → gridfm_graphkit-0.0.4}/pyproject.toml +5 -2
  25. gridfm_graphkit-0.0.4/tests/test_data_module.py +60 -0
  26. gridfm_graphkit-0.0.4/tests/test_full_pipeline.py +82 -0
  27. gridfm_graphkit-0.0.4/tests/test_losses.py +36 -0
  28. gridfm_graphkit-0.0.4/tests/test_model_outputs.py +69 -0
  29. gridfm_graphkit-0.0.4/tests/test_normalization.py +36 -0
  30. {gridfm_graphkit-0.0.2a0 → gridfm_graphkit-0.0.4}/tests/test_yaml_configs.py +23 -0
  31. gridfm_graphkit-0.0.2a0/PKG-INFO +0 -163
  32. gridfm_graphkit-0.0.2a0/README.md +0 -127
  33. gridfm_graphkit-0.0.2a0/gridfm_graphkit/__main__.py +0 -62
  34. gridfm_graphkit-0.0.2a0/gridfm_graphkit/cli.py +0 -530
  35. gridfm_graphkit-0.0.2a0/gridfm_graphkit/evaluation/node_level.py +0 -334
  36. gridfm_graphkit-0.0.2a0/gridfm_graphkit/io/param_handler.py +0 -293
  37. gridfm_graphkit-0.0.2a0/gridfm_graphkit/models/__init__.py +0 -0
  38. gridfm_graphkit-0.0.2a0/gridfm_graphkit/training/__init__.py +0 -0
  39. gridfm_graphkit-0.0.2a0/gridfm_graphkit/training/callbacks.py +0 -47
  40. gridfm_graphkit-0.0.2a0/gridfm_graphkit/training/plugins.py +0 -218
  41. gridfm_graphkit-0.0.2a0/gridfm_graphkit/training/trainer.py +0 -156
  42. gridfm_graphkit-0.0.2a0/gridfm_graphkit/utils/__init__.py +0 -0
  43. gridfm_graphkit-0.0.2a0/gridfm_graphkit/utils/visualization.py +0 -324
  44. gridfm_graphkit-0.0.2a0/gridfm_graphkit.egg-info/PKG-INFO +0 -163
  45. gridfm_graphkit-0.0.2a0/tests/test_training.py +0 -90
  46. {gridfm_graphkit-0.0.2a0 → gridfm_graphkit-0.0.4}/LICENSE +0 -0
  47. {gridfm_graphkit-0.0.2a0 → gridfm_graphkit-0.0.4}/gridfm_graphkit/datasets/globals.py +0 -0
  48. {gridfm_graphkit-0.0.2a0/gridfm_graphkit → gridfm_graphkit-0.0.4/gridfm_graphkit/io}/__init__.py +0 -0
  49. {gridfm_graphkit-0.0.2a0/gridfm_graphkit/datasets → gridfm_graphkit-0.0.4/gridfm_graphkit/tasks}/__init__.py +0 -0
  50. {gridfm_graphkit-0.0.2a0/gridfm_graphkit/evaluation → gridfm_graphkit-0.0.4/gridfm_graphkit/training}/__init__.py +0 -0
  51. {gridfm_graphkit-0.0.2a0/gridfm_graphkit/io → gridfm_graphkit-0.0.4/gridfm_graphkit/utils}/__init__.py +0 -0
  52. {gridfm_graphkit-0.0.2a0 → gridfm_graphkit-0.0.4}/gridfm_graphkit.egg-info/dependency_links.txt +0 -0
  53. {gridfm_graphkit-0.0.2a0 → gridfm_graphkit-0.0.4}/gridfm_graphkit.egg-info/entry_points.txt +0 -0
  54. {gridfm_graphkit-0.0.2a0 → gridfm_graphkit-0.0.4}/gridfm_graphkit.egg-info/top_level.txt +0 -0
  55. {gridfm_graphkit-0.0.2a0 → gridfm_graphkit-0.0.4}/setup.cfg +0 -0
@@ -0,0 +1,180 @@
1
+ Metadata-Version: 2.4
2
+ Name: gridfm-graphkit
3
+ Version: 0.0.4
4
+ Summary: Grid Foundation Model
5
+ Author-email: Matteo Mazzonelli <matteo.mazzonelli1@ibm.com>, Alban Puech <apuech@seas.harvard.edu>, Tamara Govindasamy <tamara.govindasamy@ibm.com>, Mangaliso Mngomezulu <mngomezulum@ibm.com>, Etienne Vos <etienne.vos@ibm.com>, Celia Cintas <celia.cintas@ibm.com>, Jonas Weiss <jwe@zurich.ibm.com>
6
+ Maintainer-email: Matteo Mazzonelli <matteo.mazzonelli1@ibm.com>
7
+ License-Expression: Apache-2.0
8
+ Keywords: electric power grid,foundational model,graph neural networks
9
+ Classifier: Development Status :: 2 - Pre-Alpha
10
+ Classifier: Programming Language :: Python :: 3.10
11
+ Classifier: Programming Language :: Python :: 3.11
12
+ Classifier: Programming Language :: Python :: 3.12
13
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
14
+ Requires-Python: <3.13,>=3.10
15
+ Description-Content-Type: text/markdown
16
+ License-File: LICENSE
17
+ Requires-Dist: mlflow>=3.1.0
18
+ Requires-Dist: nbformat>=5.10.4
19
+ Requires-Dist: networkx>=3.4.2
20
+ Requires-Dist: numpy>=2.2.6
21
+ Requires-Dist: pandas>=2.3.0
22
+ Requires-Dist: plotly>=6.1.2
23
+ Requires-Dist: pyyaml>=6.0.2
24
+ Requires-Dist: torch>=2.7.1
25
+ Requires-Dist: torch-geometric>=2.6.1
26
+ Requires-Dist: torchaudio>=2.7.1
27
+ Requires-Dist: torchvision>=0.22.1
28
+ Requires-Dist: lightning
29
+ Provides-Extra: dev
30
+ Requires-Dist: mkdocs-material; extra == "dev"
31
+ Requires-Dist: mkdocstrings[python]; extra == "dev"
32
+ Requires-Dist: pre-commit>=4.2.0; extra == "dev"
33
+ Requires-Dist: bandit>=1.8.5; extra == "dev"
34
+ Requires-Dist: build; extra == "dev"
35
+ Provides-Extra: test
36
+ Requires-Dist: pytest; extra == "test"
37
+ Requires-Dist: pytest-cov; extra == "test"
38
+ Dynamic: license-file
39
+
40
+ # gridfm-graphkit
41
+ [![Docs](https://img.shields.io/badge/docs-available-brightgreen)](https://gridfm.github.io/gridfm-graphkit/)
42
+ ![Coverage](https://img.shields.io/badge/coverage-83%25-yellowgreen)
43
+ ![Python](https://img.shields.io/badge/python-3.10%20%E2%80%93%203.12-blue)
44
+ ![License](https://img.shields.io/badge/license-Apache%202.0-blue)
45
+
46
+ This library is brought to you by the GridFM team to train, finetune and interact with a foundation model for the electric power grid.
47
+
48
+ ---
49
+
50
+ <p align="center">
51
+ <img src="https://raw.githubusercontent.com/gridfm/gridfm-graphkit/refs/heads/main/docs/figs/pre_training.png" alt="GridFM logo"/>
52
+ <br/>
53
+ </p>
54
+
55
+ # Installation
56
+
57
+ You can install `gridfm-graphkit` directly from PyPI:
58
+
59
+ ```bash
60
+ pip install gridfm-graphkit
61
+ ```
62
+
63
+ To contribute or develop locally, clone the repository and install in editable mode:
64
+
65
+ ```bash
66
+ git clone git@github.com:gridfm/gridfm-graphkit.git
67
+ cd gridfm-graphkit
68
+ python -m venv venv
69
+ source venv/bin/activate
70
+ pip install -e .
71
+ ```
72
+
73
+ For documentation generation and unit testing, install with the optional `dev` and `test` extras:
74
+
75
+ ```bash
76
+ pip install -e .[dev,test]
77
+ ```
78
+
79
+
80
+ # CLI commands
81
+
82
+ An interface to train, fine-tune, and evaluate GridFM models using configurable YAML files and MLflow tracking.
83
+
84
+ ```bash
85
+ gridfm_graphkit <command> [OPTIONS]
86
+ ```
87
+
88
+ Available commands:
89
+
90
+ * `train` – Train a new model from scratch
91
+ * `finetune` – Fine-tune an existing pre-trained model
92
+ * `evaluate` – Evaluate model performance on a dataset
93
+ * `predict` – Run inference and save predictions
94
+
95
+ ---
96
+
97
+ ## Training Models
98
+
99
+ ```bash
100
+ gridfm_graphkit train --config path/to/config.yaml
101
+ ```
102
+
103
+ ### Arguments
104
+
105
+ | Argument | Type | Description | Default |
106
+ | ---------------- | ------ | ---------------------------------------------------------------- | ------- |
107
+ | `--config` | `str` | **Required**. Path to the training configuration YAML file. | `None` |
108
+ | `--exp_name` | `str` | **Optional**. MLflow experiment name. | `timestamp` |
109
+ | `--run_name` | `str` | **Optional**. MLflow run name. | `run` |
110
+ | `--log_dir` | `str` | **Optional**. MLflow logging directory. | `mlruns` |
111
+ | `--data_path` | `str` | **Optional**. Root dataset directory. | `data` |
112
+
113
+ ### Examples
114
+
115
+ **Standard Training:**
116
+
117
+ ```bash
118
+ gridfm_graphkit train --config examples/config/case30_ieee_base.yaml --data_path examples/data
119
+ ```
120
+
121
+ ---
122
+
123
+ ## Fine-Tuning Models
124
+
125
+ ```bash
126
+ gridfm_graphkit finetune --config path/to/config.yaml --model_path path/to/model.pth
127
+ ```
128
+
129
+ ### Arguments
130
+
131
+ | Argument | Type | Description | Default |
132
+ | -------------- | ----- | ----------------------------------------------- | --------- |
133
+ | `--config` | `str` | **Required**. Fine-tuning configuration file. | `None` |
134
+ | `--model_path` | `str` | **Required**. Path to a pre-trained model file. | `None` |
135
+ | `--exp_name` | `str` | MLflow experiment name. | timestamp |
136
+ | `--run_name` | `str` | MLflow run name. | `run` |
137
+ | `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
138
+ | `--data_path` | `str` | Root dataset directory. | `data` |
139
+
140
+
141
+ ---
142
+
143
+ ## Evaluating Models
144
+
145
+ ```bash
146
+ gridfm_graphkit evaluate --config path/to/eval.yaml --model_path path/to/model.pth
147
+ ```
148
+
149
+ ### Arguments
150
+
151
+ | Argument | Type | Description | Default |
152
+ | -------------- | ----- | ---------------------------------------- | --------- |
153
+ | `--config` | `str` | **Required**. Path to evaluation config. | `None` |
154
+ | `--model_path` | `str` | Path to the trained model file. | `None` |
155
+ | `--exp_name` | `str` | MLflow experiment name. | timestamp |
156
+ | `--run_name` | `str` | MLflow run name. | `run` |
157
+ | `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
158
+ | `--data_path` | `str` | Dataset directory. | `data` |
159
+
160
+ ---
161
+
162
+ ## Running Predictions
163
+
164
+ ```bash
165
+ gridfm_graphkit predict --config path/to/config.yaml --model_path path/to/model.pth
166
+ ```
167
+
168
+ ### Arguments
169
+
170
+ | Argument | Type | Description | Default |
171
+ | --------------- | ----- | --------------------------------------------- | --------- |
172
+ | `--config` | `str` | **Required**. Path to prediction config file. | `None` |
173
+ | `--model_path` | `str` | Path to the trained model file. | `None` |
174
+ | `--exp_name` | `str` | MLflow experiment name. | timestamp |
175
+ | `--run_name` | `str` | MLflow run name. | `run` |
176
+ | `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
177
+ | `--data_path` | `str` | Dataset directory. | `data` |
178
+ | `--output_path` | `str` | Directory where predictions are saved. | `data` |
179
+
180
+ ---
@@ -0,0 +1,141 @@
1
+ # gridfm-graphkit
2
+ [![Docs](https://img.shields.io/badge/docs-available-brightgreen)](https://gridfm.github.io/gridfm-graphkit/)
3
+ ![Coverage](https://img.shields.io/badge/coverage-83%25-yellowgreen)
4
+ ![Python](https://img.shields.io/badge/python-3.10%20%E2%80%93%203.12-blue)
5
+ ![License](https://img.shields.io/badge/license-Apache%202.0-blue)
6
+
7
+ This library is brought to you by the GridFM team to train, finetune and interact with a foundation model for the electric power grid.
8
+
9
+ ---
10
+
11
+ <p align="center">
12
+ <img src="https://raw.githubusercontent.com/gridfm/gridfm-graphkit/refs/heads/main/docs/figs/pre_training.png" alt="GridFM logo"/>
13
+ <br/>
14
+ </p>
15
+
16
+ # Installation
17
+
18
+ You can install `gridfm-graphkit` directly from PyPI:
19
+
20
+ ```bash
21
+ pip install gridfm-graphkit
22
+ ```
23
+
24
+ To contribute or develop locally, clone the repository and install in editable mode:
25
+
26
+ ```bash
27
+ git clone git@github.com:gridfm/gridfm-graphkit.git
28
+ cd gridfm-graphkit
29
+ python -m venv venv
30
+ source venv/bin/activate
31
+ pip install -e .
32
+ ```
33
+
34
+ For documentation generation and unit testing, install with the optional `dev` and `test` extras:
35
+
36
+ ```bash
37
+ pip install -e .[dev,test]
38
+ ```
39
+
40
+
41
+ # CLI commands
42
+
43
+ An interface to train, fine-tune, and evaluate GridFM models using configurable YAML files and MLflow tracking.
44
+
45
+ ```bash
46
+ gridfm_graphkit <command> [OPTIONS]
47
+ ```
48
+
49
+ Available commands:
50
+
51
+ * `train` – Train a new model from scratch
52
+ * `finetune` – Fine-tune an existing pre-trained model
53
+ * `evaluate` – Evaluate model performance on a dataset
54
+ * `predict` – Run inference and save predictions
55
+
56
+ ---
57
+
58
+ ## Training Models
59
+
60
+ ```bash
61
+ gridfm_graphkit train --config path/to/config.yaml
62
+ ```
63
+
64
+ ### Arguments
65
+
66
+ | Argument | Type | Description | Default |
67
+ | ---------------- | ------ | ---------------------------------------------------------------- | ------- |
68
+ | `--config` | `str` | **Required**. Path to the training configuration YAML file. | `None` |
69
+ | `--exp_name` | `str` | **Optional**. MLflow experiment name. | `timestamp` |
70
+ | `--run_name` | `str` | **Optional**. MLflow run name. | `run` |
71
+ | `--log_dir` | `str` | **Optional**. MLflow logging directory. | `mlruns` |
72
+ | `--data_path` | `str` | **Optional**. Root dataset directory. | `data` |
73
+
74
+ ### Examples
75
+
76
+ **Standard Training:**
77
+
78
+ ```bash
79
+ gridfm_graphkit train --config examples/config/case30_ieee_base.yaml --data_path examples/data
80
+ ```
81
+
82
+ ---
83
+
84
+ ## Fine-Tuning Models
85
+
86
+ ```bash
87
+ gridfm_graphkit finetune --config path/to/config.yaml --model_path path/to/model.pth
88
+ ```
89
+
90
+ ### Arguments
91
+
92
+ | Argument | Type | Description | Default |
93
+ | -------------- | ----- | ----------------------------------------------- | --------- |
94
+ | `--config` | `str` | **Required**. Fine-tuning configuration file. | `None` |
95
+ | `--model_path` | `str` | **Required**. Path to a pre-trained model file. | `None` |
96
+ | `--exp_name` | `str` | MLflow experiment name. | timestamp |
97
+ | `--run_name` | `str` | MLflow run name. | `run` |
98
+ | `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
99
+ | `--data_path` | `str` | Root dataset directory. | `data` |
100
+
101
+
102
+ ---
103
+
104
+ ## Evaluating Models
105
+
106
+ ```bash
107
+ gridfm_graphkit evaluate --config path/to/eval.yaml --model_path path/to/model.pth
108
+ ```
109
+
110
+ ### Arguments
111
+
112
+ | Argument | Type | Description | Default |
113
+ | -------------- | ----- | ---------------------------------------- | --------- |
114
+ | `--config` | `str` | **Required**. Path to evaluation config. | `None` |
115
+ | `--model_path` | `str` | Path to the trained model file. | `None` |
116
+ | `--exp_name` | `str` | MLflow experiment name. | timestamp |
117
+ | `--run_name` | `str` | MLflow run name. | `run` |
118
+ | `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
119
+ | `--data_path` | `str` | Dataset directory. | `data` |
120
+
121
+ ---
122
+
123
+ ## Running Predictions
124
+
125
+ ```bash
126
+ gridfm_graphkit predict --config path/to/config.yaml --model_path path/to/model.pth
127
+ ```
128
+
129
+ ### Arguments
130
+
131
+ | Argument | Type | Description | Default |
132
+ | --------------- | ----- | --------------------------------------------- | --------- |
133
+ | `--config` | `str` | **Required**. Path to prediction config file. | `None` |
134
+ | `--model_path` | `str` | Path to the trained model file. | `None` |
135
+ | `--exp_name` | `str` | MLflow experiment name. | timestamp |
136
+ | `--run_name` | `str` | MLflow run name. | `run` |
137
+ | `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
138
+ | `--data_path` | `str` | Dataset directory. | `data` |
139
+ | `--output_path` | `str` | Directory where predictions are saved. | `data` |
140
+
141
+ ---
@@ -0,0 +1,6 @@
1
+ import gridfm_graphkit.datasets
2
+ import gridfm_graphkit.models
3
+
4
+ __all__ = [
5
+ "gridfm_graphkit",
6
+ ]
@@ -0,0 +1,58 @@
1
+ import argparse
2
+ from datetime import datetime
3
+ from gridfm_graphkit.cli import main_cli
4
+
5
+
6
def main():
    """Parse the ``gridfm_graphkit`` command line and hand it to ``main_cli``.

    One sub-parser is registered per command (``train``, ``finetune``,
    ``evaluate``, ``predict``); the chosen command is stored on the parsed
    namespace as ``command``.

    Options:
        --config:      required for every command.
        --model_path:  required for ``finetune``; optional (default ``None``)
                       for ``evaluate`` and ``predict``; absent for ``train``.
        --exp_name:    defaults to a ``YYYY-MM-DD_HH-MM-SS`` timestamp.
        --run_name:    defaults to ``"run"``.
        --log_dir:     defaults to ``"mlruns"``.
        --data_path:   defaults to ``"data"``.
        --output_path: ``predict`` only, defaults to ``"data"``.
    """
    parser = argparse.ArgumentParser(
        prog="gridfm_graphkit",
        description="gridfm-graphkit CLI",
    )
    subparsers = parser.add_subparsers(dest="command", required=True)
    # Computed once so that every sub-command shares the same default name.
    exp_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    def add_shared_options(subparser):
        """Register the logging/data options common to all sub-commands."""
        subparser.add_argument("--exp_name", type=str, default=exp_name)
        subparser.add_argument("--run_name", type=str, default="run")
        subparser.add_argument("--log_dir", type=str, default="mlruns")
        subparser.add_argument("--data_path", type=str, default="data")

    # ---- TRAIN SUBCOMMAND ----
    train_parser = subparsers.add_parser("train", help="Run training")
    train_parser.add_argument("--config", type=str, required=True)
    add_shared_options(train_parser)

    # ---- FINETUNE SUBCOMMAND ----
    finetune_parser = subparsers.add_parser("finetune", help="Run fine-tuning")
    finetune_parser.add_argument("--config", type=str, required=True)
    finetune_parser.add_argument("--model_path", type=str, required=True)
    add_shared_options(finetune_parser)

    # ---- EVALUATE SUBCOMMAND ----
    evaluate_parser = subparsers.add_parser(
        "evaluate",
        help="Evaluate model performance",
    )
    evaluate_parser.add_argument("--model_path", type=str, default=None)
    evaluate_parser.add_argument("--config", type=str, required=True)
    add_shared_options(evaluate_parser)

    # ---- PREDICT SUBCOMMAND ----
    predict_parser = subparsers.add_parser(
        "predict",
        help="Run inference and save predictions",
    )
    # ``default=None`` (not ``required=None``): the option is optional and
    # simply absent-valued when not supplied, mirroring ``evaluate``.
    predict_parser.add_argument("--model_path", type=str, default=None)
    predict_parser.add_argument("--config", type=str, required=True)
    add_shared_options(predict_parser)
    predict_parser.add_argument("--output_path", type=str, default="data")

    args = parser.parse_args()
    main_cli(args)
+
56
+
57
+ if __name__ == "__main__":
58
+ main()
@@ -0,0 +1,134 @@
1
+ from gridfm_graphkit.datasets.powergrid_datamodule import LitGridDataModule
2
+ from gridfm_graphkit.io.param_handler import NestedNamespace
3
+ from gridfm_graphkit.training.callbacks import SaveBestModelStateDict
4
+ import numpy as np
5
+ import os
6
+ import yaml
7
+ import torch
8
+ import random
9
+ import pandas as pd
10
+
11
+ from gridfm_graphkit.tasks.feature_reconstruction_task import FeatureReconstructionTask
12
+ from lightning.pytorch.callbacks.early_stopping import EarlyStopping
13
+ from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
14
+ from lightning.pytorch.loggers import MLFlowLogger
15
+ import lightning as L
16
+
17
+
18
def get_training_callbacks(args):
    """Build the list of Lightning callbacks attached to the trainer.

    All three callbacks watch the ``"Validation loss"`` metric and treat
    lower values as better.

    Args:
        args: Configuration object exposing ``callbacks.tol`` (used as the
            early-stopping ``min_delta``) and ``callbacks.patience``.

    Returns:
        list: ``[EarlyStopping, SaveBestModelStateDict, ModelCheckpoint]``,
        in that order.
    """
    monitored_metric = "Validation loss"
    direction = "min"

    return [
        # Stop once the monitored metric has not improved by at least
        # ``tol`` for ``patience`` consecutive checks.
        EarlyStopping(
            monitor=monitored_metric,
            min_delta=args.callbacks.tol,
            patience=args.callbacks.patience,
            verbose=False,
            mode=direction,
        ),
        # Project callback; writes to the file name given below.
        SaveBestModelStateDict(
            monitor=monitored_metric,
            mode=direction,
            filename="best_model_state_dict.pt",
        ),
        # ``save_top_k=0`` with ``save_last=True``: no ranked checkpoints,
        # only the most recent one is requested.
        ModelCheckpoint(
            monitor=monitored_metric,
            mode=direction,
            save_last=True,
            save_top_k=0,
        ),
    ]
41
+
42
+
43
def _seed_everything(seed):
    """Seed torch, the stdlib ``random`` module and NumPy with ``seed``."""
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)


def _predictions_to_dataframe(predictions):
    """Concatenate per-batch prediction dicts into a single DataFrame.

    Each element of ``predictions`` must provide the keys ``"output"``,
    ``"mask_PQ"``, ``"mask_PV"``, ``"mask_REF"``, ``"scenario_id"`` and
    ``"bus_number"``. Columns 0-5 of the concatenated ``"output"`` array are
    written out as PD, QD, PG, QG, VM and VA respectively.
    """
    keys = ("output", "mask_PQ", "mask_PV", "mask_REF", "scenario_id", "bus_number")
    stacked = {
        key: np.concatenate([batch[key] for batch in predictions], axis=0)
        for key in keys
    }
    outputs = stacked["output"]

    return pd.DataFrame(
        {
            "scenario": stacked["scenario_id"],
            "bus": stacked["bus_number"],
            "PD": outputs[:, 0],
            "QD": outputs[:, 1],
            "PG": outputs[:, 2],
            "QG": outputs[:, 3],
            "VM": outputs[:, 4],
            "VA": outputs[:, 5],
            "PQ": stacked["mask_PQ"].astype(int),
            "PV": stacked["mask_PV"].astype(int),
            "REF": stacked["mask_REF"].astype(int),
        },
    )


def main_cli(args):
    """Run the sub-command selected on the command line.

    Steps: create the MLflow logger, load the YAML config, seed the RNGs,
    build the data module and the ``FeatureReconstructionTask``, optionally
    load saved weights, then fit / test / predict depending on
    ``args.command``:

    * ``train`` and ``finetune`` call ``trainer.fit`` and then ``trainer.test``.
    * ``evaluate`` calls ``trainer.test`` only.
    * ``predict`` calls ``trainer.predict`` and writes ``predictions.csv``
      into ``args.output_path``.

    Args:
        args: Parsed CLI namespace. Reads ``command``, ``config``,
            ``log_dir``, ``exp_name``, ``run_name`` and ``data_path``; for
            every command other than ``train`` also ``model_path``; for
            ``predict`` also ``output_path``.
    """
    logger = MLFlowLogger(
        save_dir=args.log_dir,
        experiment_name=args.exp_name,
        run_name=args.run_name,
    )

    with open(args.config, "r") as f:
        base_config = yaml.safe_load(f)

    config_args = NestedNamespace(**base_config)

    _seed_everything(config_args.seed)

    litGrid = LitGridDataModule(config_args, args.data_path)
    model = FeatureReconstructionTask(
        config_args,
        litGrid.node_normalizers,
        litGrid.edge_normalizers,
    )

    if args.command != "train":
        if args.model_path is None:
            # ``--model_path`` is optional for evaluate/predict; passing None
            # to torch.load would fail, so keep the model's initial weights.
            print("No --model_path given; using the model's initial weights")
        else:
            print(f"Loading model weights from {args.model_path}")
            # Deserialize onto CPU so a state dict saved on a GPU also loads
            # on CPU-only hosts; load_state_dict copies the values into the
            # model's own parameters.
            state_dict = torch.load(args.model_path, map_location="cpu")
            model.load_state_dict(state_dict)

    trainer = L.Trainer(
        logger=logger,
        accelerator=config_args.training.accelerator,
        devices=config_args.training.devices,
        strategy=config_args.training.strategy,
        log_every_n_steps=1,
        default_root_dir=args.log_dir,
        max_epochs=config_args.training.epochs,
        callbacks=get_training_callbacks(config_args),
    )

    if args.command in ("train", "finetune"):
        trainer.fit(model=model, datamodule=litGrid)

    if args.command != "predict":
        trainer.test(model=model, datamodule=litGrid)
        return

    predictions = trainer.predict(model=model, datamodule=litGrid)
    df = _predictions_to_dataframe(predictions)

    # Save CSV
    output_dir = args.output_path
    os.makedirs(output_dir, exist_ok=True)
    csv_path = os.path.join(output_dir, "predictions.csv")
    df.to_csv(csv_path, index=False)

    print(f"Saved predictions to {csv_path}")
@@ -0,0 +1,23 @@
1
+ from gridfm_graphkit.datasets.transforms import (
2
+ AddPFMask,
3
+ AddIdentityMask,
4
+ AddRandomMask,
5
+ AddOPFMask,
6
+ )
7
+ from gridfm_graphkit.datasets.normalizers import (
8
+ Standardizer,
9
+ MinMaxNormalizer,
10
+ BaseMVANormalizer,
11
+ IdentityNormalizer,
12
+ )
13
+
14
+ __all__ = [
15
+ "AddPFMask",
16
+ "AddIdentityMask",
17
+ "AddRandomMask",
18
+ "AddOPFMask",
19
+ "Standardizer",
20
+ "MinMaxNormalizer",
21
+ "BaseMVANormalizer",
22
+ "IdentityNormalizer",
23
+ ]
@@ -1,4 +1,5 @@
1
1
  from gridfm_graphkit.datasets.globals import PD, QD, PG, QG, VA
2
+ from gridfm_graphkit.io.registries import NORMALIZERS_REGISTRY
2
3
  import torch
3
4
  from abc import ABC, abstractmethod
4
5
 
@@ -53,13 +54,25 @@ class Normalizer(ABC):
53
54
  Original tensor.
54
55
  """
55
56
 
57
+ @abstractmethod
58
+ def get_stats(self) -> dict:
59
+ """
60
+ Return the stored normalization statistics for logging/inspection.
61
+ """
62
+
56
63
 
64
+ @NORMALIZERS_REGISTRY.register("minmax")
57
65
  class MinMaxNormalizer(Normalizer):
58
66
  """
59
67
  Scales each feature to the [0, 1] range.
68
+
69
+ Args:
70
+ node_data (bool): Whether data is node-level or edge-level
71
+ args (NestedNamespace): Parameters
72
+
60
73
  """
61
74
 
62
- def __init__(self):
75
+ def __init__(self, node_data: bool, args):
63
76
  self.min_val = None
64
77
  self.max_val = None
65
78
 
@@ -95,13 +108,25 @@ class MinMaxNormalizer(Normalizer):
95
108
  diff[diff == 0] = 1
96
109
  return (normalized_data * diff) + self.min_val
97
110
 
111
+ def get_stats(self) -> dict:
112
+ return {
113
+ "min_value": self.min_val.tolist() if self.min_val is not None else None,
114
+ "max_value": self.max_val.tolist() if self.max_val is not None else None,
115
+ }
98
116
 
117
+
118
+ @NORMALIZERS_REGISTRY.register("standard")
99
119
  class Standardizer(Normalizer):
100
120
  """
101
121
  Standardizes each feature to zero mean and unit variance.
122
+
123
+ Args:
124
+ node_data (bool): Whether data is node-level or edge-level
125
+ args (NestedNamespace): Parameters
126
+
102
127
  """
103
128
 
104
- def __init__(self):
129
+ def __init__(self, node_data: bool, args):
105
130
  self.mean = None
106
131
  self.std = None
107
132
 
@@ -137,7 +162,14 @@ class Standardizer(Normalizer):
137
162
  std[std == 0] = 1
138
163
  return (normalized_data * std) + self.mean
139
164
 
165
+ def get_stats(self) -> dict:
166
+ return {
167
+ "mean": self.mean.tolist() if self.mean is not None else None,
168
+ "std": self.std.tolist() if self.std is not None else None,
169
+ }
170
+
140
171
 
172
+ @NORMALIZERS_REGISTRY.register("baseMVAnorm")
141
173
  class BaseMVANormalizer(Normalizer):
142
174
  """
143
175
  In power systems, a suitable normalization strategy must preserve the physical properties of
@@ -148,14 +180,17 @@ class BaseMVANormalizer(Normalizer):
148
180
  preserving fundamental physical relationships.
149
181
  """
150
182
 
151
- def __init__(self, node_data: bool, baseMVA_orig: float = 100.0):
183
+ def __init__(self, node_data: bool, args):
152
184
  """
153
185
  Args:
154
- node_data: Whether data is node-level or edge-level (PD, QD, PG, QG, VA).
155
- baseMVA_orig: Original baseMVA (e.g. from MATPOWER).
186
+ node_data: Whether data is node-level or edge-level
187
+ args (NestedNamespace): Parameters
188
+
189
+ Attributes:
190
+ baseMVA (float): baseMVA found in casefile. From ``args.data.baseMVA``.
156
191
  """
157
192
  self.node_data = node_data
158
- self.baseMVA_orig = baseMVA_orig
193
+ self.baseMVA_orig = getattr(args.data, "baseMVA", 100)
159
194
  self.baseMVA = None
160
195
 
161
196
  def to(self, device):
@@ -208,12 +243,26 @@ class BaseMVANormalizer(Normalizer):
208
243
 
209
244
  return normalized_data
210
245
 
246
+ def get_stats(self) -> dict:
247
+ return {
248
+ "baseMVA": self.baseMVA,
249
+ "baseMVA_orig": self.baseMVA_orig,
250
+ }
251
+
211
252
 
253
+ @NORMALIZERS_REGISTRY.register("identity")
212
254
  class IdentityNormalizer(Normalizer):
213
255
  """
214
256
  No normalization: returns data unchanged.
257
+
258
+ Args:
259
+ node_data: Whether data is node-level or edge-level
260
+ args (NestedNamespace): Parameters
215
261
  """
216
262
 
263
+ def __init__(self, node_data: bool, args):
264
+ pass
265
+
217
266
  def fit(self, data: torch.Tensor) -> dict:
218
267
  return {}
219
268
 
@@ -225,3 +274,6 @@ class IdentityNormalizer(Normalizer):
225
274
 
226
275
  def inverse_transform(self, normalized_data: torch.Tensor) -> torch.Tensor:
227
276
  return normalized_data
277
+
278
+ def get_stats(self) -> dict:
279
+ return {"note": "No normalization applied."}