gridfm-graphkit 0.0.3__tar.gz → 0.0.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.5}/PKG-INFO +52 -44
  2. gridfm_graphkit-0.0.5/README.md +141 -0
  3. gridfm_graphkit-0.0.5/gridfm_graphkit/__init__.py +6 -0
  4. gridfm_graphkit-0.0.5/gridfm_graphkit/__main__.py +58 -0
  5. gridfm_graphkit-0.0.5/gridfm_graphkit/cli.py +134 -0
  6. gridfm_graphkit-0.0.5/gridfm_graphkit/datasets/__init__.py +23 -0
  7. gridfm_graphkit-0.0.3/gridfm_graphkit/datasets/data_normalization.py → gridfm_graphkit-0.0.5/gridfm_graphkit/datasets/normalizers.py +58 -6
  8. gridfm_graphkit-0.0.5/gridfm_graphkit/datasets/postprocessing.py +83 -0
  9. gridfm_graphkit-0.0.5/gridfm_graphkit/datasets/powergrid_datamodule.py +207 -0
  10. gridfm_graphkit-0.0.3/gridfm_graphkit/datasets/powergrid.py → gridfm_graphkit-0.0.5/gridfm_graphkit/datasets/powergrid_dataset.py +57 -22
  11. {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.5}/gridfm_graphkit/datasets/transforms.py +17 -3
  12. {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.5}/gridfm_graphkit/datasets/utils.py +0 -16
  13. gridfm_graphkit-0.0.5/gridfm_graphkit/io/param_handler.py +138 -0
  14. gridfm_graphkit-0.0.5/gridfm_graphkit/io/registries.py +42 -0
  15. gridfm_graphkit-0.0.5/gridfm_graphkit/models/__init__.py +4 -0
  16. gridfm_graphkit-0.0.3/gridfm_graphkit/models/graphTransformer.py → gridfm_graphkit-0.0.5/gridfm_graphkit/models/gnn_transformer.py +34 -34
  17. {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.5}/gridfm_graphkit/models/gps_transformer.py +40 -43
  18. gridfm_graphkit-0.0.5/gridfm_graphkit/tasks/feature_reconstruction_task.py +366 -0
  19. gridfm_graphkit-0.0.5/gridfm_graphkit/training/callbacks.py +49 -0
  20. {gridfm_graphkit-0.0.3/gridfm_graphkit/utils → gridfm_graphkit-0.0.5/gridfm_graphkit/training}/loss.py +9 -9
  21. gridfm_graphkit-0.0.5/gridfm_graphkit/utils/utils.py +42 -0
  22. gridfm_graphkit-0.0.5/gridfm_graphkit/utils/visualization.py +513 -0
  23. {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.5}/gridfm_graphkit.egg-info/PKG-INFO +52 -44
  24. {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.5}/gridfm_graphkit.egg-info/SOURCES.txt +14 -9
  25. {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.5}/gridfm_graphkit.egg-info/requires.txt +2 -0
  26. {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.5}/pyproject.toml +3 -1
  27. gridfm_graphkit-0.0.5/tests/test_data_module.py +60 -0
  28. gridfm_graphkit-0.0.5/tests/test_full_pipeline.py +82 -0
  29. gridfm_graphkit-0.0.5/tests/test_losses.py +36 -0
  30. gridfm_graphkit-0.0.5/tests/test_model_outputs.py +69 -0
  31. gridfm_graphkit-0.0.5/tests/test_normalization.py +36 -0
  32. {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.5}/tests/test_yaml_configs.py +23 -0
  33. gridfm_graphkit-0.0.3/README.md +0 -135
  34. gridfm_graphkit-0.0.3/gridfm_graphkit/__main__.py +0 -62
  35. gridfm_graphkit-0.0.3/gridfm_graphkit/cli.py +0 -530
  36. gridfm_graphkit-0.0.3/gridfm_graphkit/evaluation/node_level.py +0 -334
  37. gridfm_graphkit-0.0.3/gridfm_graphkit/io/param_handler.py +0 -293
  38. gridfm_graphkit-0.0.3/gridfm_graphkit/models/__init__.py +0 -0
  39. gridfm_graphkit-0.0.3/gridfm_graphkit/training/__init__.py +0 -0
  40. gridfm_graphkit-0.0.3/gridfm_graphkit/training/callbacks.py +0 -47
  41. gridfm_graphkit-0.0.3/gridfm_graphkit/training/plugins.py +0 -218
  42. gridfm_graphkit-0.0.3/gridfm_graphkit/training/trainer.py +0 -156
  43. gridfm_graphkit-0.0.3/gridfm_graphkit/utils/__init__.py +0 -0
  44. gridfm_graphkit-0.0.3/gridfm_graphkit/utils/visualization.py +0 -324
  45. gridfm_graphkit-0.0.3/tests/test_model_outputs.py +0 -55
  46. gridfm_graphkit-0.0.3/tests/test_training.py +0 -90
  47. {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.5}/LICENSE +0 -0
  48. {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.5}/gridfm_graphkit/datasets/globals.py +0 -0
  49. {gridfm_graphkit-0.0.3/gridfm_graphkit → gridfm_graphkit-0.0.5/gridfm_graphkit/io}/__init__.py +0 -0
  50. {gridfm_graphkit-0.0.3/gridfm_graphkit/datasets → gridfm_graphkit-0.0.5/gridfm_graphkit/tasks}/__init__.py +0 -0
  51. {gridfm_graphkit-0.0.3/gridfm_graphkit/evaluation → gridfm_graphkit-0.0.5/gridfm_graphkit/training}/__init__.py +0 -0
  52. {gridfm_graphkit-0.0.3/gridfm_graphkit/io → gridfm_graphkit-0.0.5/gridfm_graphkit/utils}/__init__.py +0 -0
  53. {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.5}/gridfm_graphkit.egg-info/dependency_links.txt +0 -0
  54. {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.5}/gridfm_graphkit.egg-info/entry_points.txt +0 -0
  55. {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.5}/gridfm_graphkit.egg-info/top_level.txt +0 -0
  56. {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.5}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gridfm-graphkit
3
- Version: 0.0.3
3
+ Version: 0.0.5
4
4
  Summary: Grid Foundation Model
5
5
  Author-email: Matteo Mazzonelli <matteo.mazzonelli1@ibm.com>, Alban Puech <apuech@seas.harvard.edu>, Tamara Govindasamy <tamara.govindasamy@ibm.com>, Mangaliso Mngomezulu <mngomezulum@ibm.com>, Etienne Vos <etienne.vos@ibm.com>, Celia Cintas <celia.cintas@ibm.com>, Jonas Weiss <jwe@zurich.ibm.com>
6
6
  Maintainer-email: Matteo Mazzonelli <matteo.mazzonelli1@ibm.com>
@@ -25,6 +25,8 @@ Requires-Dist: torch>=2.7.1
25
25
  Requires-Dist: torch-geometric>=2.6.1
26
26
  Requires-Dist: torchaudio>=2.7.1
27
27
  Requires-Dist: torchvision>=0.22.1
28
+ Requires-Dist: lightning
29
+ Requires-Dist: seaborn
28
30
  Provides-Extra: dev
29
31
  Requires-Dist: mkdocs-material; extra == "dev"
30
32
  Requires-Dist: mkdocstrings[python]; extra == "dev"
@@ -38,6 +40,9 @@ Dynamic: license-file
38
40
 
39
41
  # gridfm-graphkit
40
42
  [![Docs](https://img.shields.io/badge/docs-available-brightgreen)](https://gridfm.github.io/gridfm-graphkit/)
43
+ ![Coverage](https://img.shields.io/badge/coverage-83%25-yellowgreen)
44
+ ![Python](https://img.shields.io/badge/python-3.10%20%E2%80%93%203.12-blue)
45
+ ![License](https://img.shields.io/badge/license-Apache%202.0-blue)
41
46
 
42
47
  This library is brought to you by the GridFM team to train, finetune and interact with a foundation model for the electric power grid.
43
48
 
@@ -73,7 +78,7 @@ pip install -e .[dev,test]
73
78
  ```
74
79
 
75
80
 
76
- # gridfm-graphkit CLI
81
+ # CLI commands
77
82
 
78
83
  An interface to train, fine-tune, and evaluate GridFM models using configurable YAML files and MLflow tracking.
79
84
 
@@ -83,9 +88,10 @@ gridfm_graphkit <command> [OPTIONS]
83
88
 
84
89
  Available commands:
85
90
 
86
- * `train` – Train a new model
87
- * `predict` – Evaluate an existing model
88
- * `finetune` – Fine-tune a pre-trained model
91
+ * `train` – Train a new model from scratch
92
+ * `finetune` – Fine-tune an existing pre-trained model
93
+ * `evaluate` – Evaluate model performance on a dataset
94
+ * `predict` – Run inference and save predictions
89
95
 
90
96
  ---
91
97
 
@@ -99,75 +105,77 @@ gridfm_graphkit train --config path/to/config.yaml
99
105
 
100
106
  | Argument | Type | Description | Default |
101
107
  | ---------------- | ------ | ---------------------------------------------------------------- | ------- |
102
- | `--config` | `str` | **Required for standard training**. Path to base config YAML. | `None` |
103
- | `--grid` | `str` | **Optional**. Path to grid search YAML. Not supported with `-c`. | `None` |
104
- | `--exp` | `str` | **Optional**. MLflow experiment name. Defaults to a timestamp. | `None` |
108
+ | `--config` | `str` | **Required**. Path to the training configuration YAML file. | `None` |
109
+ | `--exp_name` | `str` | **Optional**. MLflow experiment name. | `timestamp` |
110
+ | `--run_name` | `str` | **Optional**. MLflow run name. | `run` |
111
+ | `--log_dir` | `str` | **Optional**. MLflow logging directory. | `mlruns` |
105
112
  | `--data_path` | `str` | **Optional**. Root dataset directory. | `data` |
106
- | `-c` | `flag` | **Optional**. Enable checkpoint mode. | `False` |
107
- | `--model_exp_id` | `str` | **Required if `-c` is used**. MLflow experiment ID. | `None` |
108
- | `--model_run_id` | `str` | **Required if `-c` is used**. MLflow run ID. | `None` |
109
113
 
110
114
  ### Examples
111
115
 
112
116
  **Standard Training:**
113
117
 
114
118
  ```bash
115
- gridfm_graphkit train --config config/train.yaml --exp "run1"
119
+ gridfm_graphkit train --config examples/config/case30_ieee_base.yaml --data_path examples/data
116
120
  ```
117
121
 
118
- **Grid Search Training:**
122
+ ---
123
+
124
+ ## Fine-Tuning Models
119
125
 
120
126
  ```bash
121
- gridfm_graphkit train --config config/train.yaml --grid config/grid.yaml
127
+ gridfm_graphkit finetune --config path/to/config.yaml --model_path path/to/model.pth
122
128
  ```
123
129
 
124
- **Training from Checkpoint:**
130
+ ### Arguments
131
+
132
+ | Argument | Type | Description | Default |
133
+ | -------------- | ----- | ----------------------------------------------- | --------- |
134
+ | `--config` | `str` | **Required**. Fine-tuning configuration file. | `None` |
135
+ | `--model_path` | `str` | **Required**. Path to a pre-trained model file. | `None` |
136
+ | `--exp_name` | `str` | MLflow experiment name. | timestamp |
137
+ | `--run_name` | `str` | MLflow run name. | `run` |
138
+ | `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
139
+ | `--data_path` | `str` | Root dataset directory. | `data` |
125
140
 
126
- ```bash
127
- gridfm_graphkit train -c --model_exp_id 123 --model_run_id abc
128
- ```
129
141
 
130
142
  ---
131
143
 
132
144
  ## Evaluating Models
133
145
 
134
146
  ```bash
135
- gridfm_graphkit predict --model_path model.pth --config config/eval.yaml --eval_name run_eval
147
+ gridfm_graphkit evaluate --config path/to/eval.yaml --model_path path/to/model.pth
136
148
  ```
137
149
 
138
150
  ### Arguments
139
151
 
140
- | Argument | Type | Description | Default |
141
- | ---------------- | ----- | ----------------------------------------------------------------- | ------------ |
142
- | `--model_path` | `str` | **Optional**. Path to a model file. | `None` |
143
- | `--model_exp_id` | `str` | **Required if `--model_path` is not used**. MLflow experiment ID. | `None` |
144
- | `--model_run_id` | `str` | **Required if `--model_path` is not used**. MLflow run ID. | `None` |
145
- | `--model_name` | `str` | **Optional**. Filename inside MLflow artifacts. | `best_model` |
146
- | `--config` | `str` | **Required**. Path to evaluation config. | `None` |
147
- | `--eval_name` | `str` | **Required**. Name of the evaluation run in MLflow. | `None` |
148
- | `--data_path` | `str` | **Optional**. Path to dataset directory. | `data` |
149
-
150
- ### Examples
151
-
152
- **Evaluate a Logged MLflow Model:**
153
-
154
- ```bash
155
- gridfm_graphkit predict --config config/eval.yaml --eval_name run_eval --model_exp_id 1 --model_run_id abc
156
- ```
152
+ | Argument | Type | Description | Default |
153
+ | -------------- | ----- | ---------------------------------------- | --------- |
154
+ | `--config` | `str` | **Required**. Path to evaluation config. | `None` |
155
+ | `--model_path` | `str` | Path to the trained model file. | `None` |
156
+ | `--exp_name` | `str` | MLflow experiment name. | timestamp |
157
+ | `--run_name` | `str` | MLflow run name. | `run` |
158
+ | `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
159
+ | `--data_path` | `str` | Dataset directory. | `data` |
157
160
 
158
161
  ---
159
162
 
160
- ## Fine-Tuning Models
163
+ ## Running Predictions
161
164
 
162
165
  ```bash
163
- gridfm_graphkit finetune --config path/to/config.yaml --model_path path/to/model.pth
166
+ gridfm_graphkit predict --config path/to/config.yaml --model_path path/to/model.pth
164
167
  ```
165
168
 
166
169
  ### Arguments
167
170
 
168
- | Argument | Type | Description | Default |
169
- | -------------- | ----- | ----------------------------------------------- | ------- |
170
- | `--config` | `str` | **Required**. Fine-tuning configuration file. | `None` |
171
- | `--model_path` | `str` | **Required**. Path to a pre-trained model file. | `None` |
172
- | `--exp` | `str` | **Optional**. MLflow experiment name. | `None` |
173
- | `--data_path` | `str` | **Optional**. Root dataset directory. | `data` |
171
+ | Argument | Type | Description | Default |
172
+ | --------------- | ----- | --------------------------------------------- | --------- |
173
+ | `--config` | `str` | **Required**. Path to prediction config file. | `None` |
174
+ | `--model_path` | `str` | Path to the trained model file. | `None` |
175
+ | `--exp_name` | `str` | MLflow experiment name. | timestamp |
176
+ | `--run_name` | `str` | MLflow run name. | `run` |
177
+ | `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
178
+ | `--data_path` | `str` | Dataset directory. | `data` |
179
+ | `--output_path` | `str` | Directory where predictions are saved. | `data` |
180
+
181
+ ---
@@ -0,0 +1,141 @@
1
+ # gridfm-graphkit
2
+ [![Docs](https://img.shields.io/badge/docs-available-brightgreen)](https://gridfm.github.io/gridfm-graphkit/)
3
+ ![Coverage](https://img.shields.io/badge/coverage-83%25-yellowgreen)
4
+ ![Python](https://img.shields.io/badge/python-3.10%20%E2%80%93%203.12-blue)
5
+ ![License](https://img.shields.io/badge/license-Apache%202.0-blue)
6
+
7
+ This library is brought to you by the GridFM team to train, finetune and interact with a foundation model for the electric power grid.
8
+
9
+ ---
10
+
11
+ <p align="center">
12
+ <img src="https://raw.githubusercontent.com/gridfm/gridfm-graphkit/refs/heads/main/docs/figs/pre_training.png" alt="GridFM logo"/>
13
+ <br/>
14
+ </p>
15
+
16
+ # Installation
17
+
18
+ You can install `gridfm-graphkit` directly from PyPI:
19
+
20
+ ```bash
21
+ pip install gridfm-graphkit
22
+ ```
23
+
24
+ To contribute or develop locally, clone the repository and install in editable mode:
25
+
26
+ ```bash
27
+ git clone git@github.com:gridfm/gridfm-graphkit.git
28
+ cd gridfm-graphkit
29
+ python -m venv venv
30
+ source venv/bin/activate
31
+ pip install -e .
32
+ ```
33
+
34
+ For documentation generation and unit testing, install with the optional `dev` and `test` extras:
35
+
36
+ ```bash
37
+ pip install -e .[dev,test]
38
+ ```
39
+
40
+
41
+ # CLI commands
42
+
43
+ An interface to train, fine-tune, and evaluate GridFM models using configurable YAML files and MLflow tracking.
44
+
45
+ ```bash
46
+ gridfm_graphkit <command> [OPTIONS]
47
+ ```
48
+
49
+ Available commands:
50
+
51
+ * `train` – Train a new model from scratch
52
+ * `finetune` – Fine-tune an existing pre-trained model
53
+ * `evaluate` – Evaluate model performance on a dataset
54
+ * `predict` – Run inference and save predictions
55
+
56
+ ---
57
+
58
+ ## Training Models
59
+
60
+ ```bash
61
+ gridfm_graphkit train --config path/to/config.yaml
62
+ ```
63
+
64
+ ### Arguments
65
+
66
+ | Argument | Type | Description | Default |
67
+ | ---------------- | ------ | ---------------------------------------------------------------- | ------- |
68
+ | `--config` | `str` | **Required**. Path to the training configuration YAML file. | `None` |
69
+ | `--exp_name` | `str` | **Optional**. MLflow experiment name. | `timestamp` |
70
+ | `--run_name` | `str` | **Optional**. MLflow run name. | `run` |
71
+ | `--log_dir` | `str` | **Optional**. MLflow logging directory. | `mlruns` |
72
+ | `--data_path` | `str` | **Optional**. Root dataset directory. | `data` |
73
+
74
+ ### Examples
75
+
76
+ **Standard Training:**
77
+
78
+ ```bash
79
+ gridfm_graphkit train --config examples/config/case30_ieee_base.yaml --data_path examples/data
80
+ ```
81
+
82
+ ---
83
+
84
+ ## Fine-Tuning Models
85
+
86
+ ```bash
87
+ gridfm_graphkit finetune --config path/to/config.yaml --model_path path/to/model.pth
88
+ ```
89
+
90
+ ### Arguments
91
+
92
+ | Argument | Type | Description | Default |
93
+ | -------------- | ----- | ----------------------------------------------- | --------- |
94
+ | `--config` | `str` | **Required**. Fine-tuning configuration file. | `None` |
95
+ | `--model_path` | `str` | **Required**. Path to a pre-trained model file. | `None` |
96
+ | `--exp_name` | `str` | MLflow experiment name. | timestamp |
97
+ | `--run_name` | `str` | MLflow run name. | `run` |
98
+ | `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
99
+ | `--data_path` | `str` | Root dataset directory. | `data` |
100
+
101
+
102
+ ---
103
+
104
+ ## Evaluating Models
105
+
106
+ ```bash
107
+ gridfm_graphkit evaluate --config path/to/eval.yaml --model_path path/to/model.pth
108
+ ```
109
+
110
+ ### Arguments
111
+
112
+ | Argument | Type | Description | Default |
113
+ | -------------- | ----- | ---------------------------------------- | --------- |
114
+ | `--config` | `str` | **Required**. Path to evaluation config. | `None` |
115
+ | `--model_path` | `str` | Path to the trained model file. | `None` |
116
+ | `--exp_name` | `str` | MLflow experiment name. | timestamp |
117
+ | `--run_name` | `str` | MLflow run name. | `run` |
118
+ | `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
119
+ | `--data_path` | `str` | Dataset directory. | `data` |
120
+
121
+ ---
122
+
123
+ ## Running Predictions
124
+
125
+ ```bash
126
+ gridfm_graphkit predict --config path/to/config.yaml --model_path path/to/model.pth
127
+ ```
128
+
129
+ ### Arguments
130
+
131
+ | Argument | Type | Description | Default |
132
+ | --------------- | ----- | --------------------------------------------- | --------- |
133
+ | `--config` | `str` | **Required**. Path to prediction config file. | `None` |
134
+ | `--model_path` | `str` | Path to the trained model file. | `None` |
135
+ | `--exp_name` | `str` | MLflow experiment name. | timestamp |
136
+ | `--run_name` | `str` | MLflow run name. | `run` |
137
+ | `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
138
+ | `--data_path` | `str` | Dataset directory. | `data` |
139
+ | `--output_path` | `str` | Directory where predictions are saved. | `data` |
140
+
141
+ ---
@@ -0,0 +1,6 @@
1
+ import gridfm_graphkit.datasets
2
+ import gridfm_graphkit.models
3
+
4
+ __all__ = [
5
+ "gridfm_graphkit",
6
+ ]
@@ -0,0 +1,58 @@
1
+ import argparse
2
+ from datetime import datetime
3
+ from gridfm_graphkit.cli import main_cli
4
+
5
+
6
+ def main():
7
+ parser = argparse.ArgumentParser(
8
+ prog="gridfm_graphkit",
9
+ description="gridfm-graphkit CLI",
10
+ )
11
+ subparsers = parser.add_subparsers(dest="command", required=True)
12
+ exp_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
13
+
14
+ # ---- TRAIN SUBCOMMAND ----
15
+ train_parser = subparsers.add_parser("train", help="Run training")
16
+ train_parser.add_argument("--config", type=str, required=True)
17
+ train_parser.add_argument("--exp_name", type=str, default=exp_name)
18
+ train_parser.add_argument("--run_name", type=str, default="run")
19
+ train_parser.add_argument("--log_dir", type=str, default="mlruns")
20
+ train_parser.add_argument("--data_path", type=str, default="data")
21
+
22
+ # ---- FINETUNE SUBCOMMAND ----
23
+ finetune_parser = subparsers.add_parser("finetune", help="Run fine-tuning")
24
+ finetune_parser.add_argument("--config", type=str, required=True)
25
+ finetune_parser.add_argument("--model_path", type=str, required=True)
26
+ finetune_parser.add_argument("--exp_name", type=str, default=exp_name)
27
+ finetune_parser.add_argument("--run_name", type=str, default="run")
28
+ finetune_parser.add_argument("--log_dir", type=str, default="mlruns")
29
+ finetune_parser.add_argument("--data_path", type=str, default="data")
30
+
31
+ # ---- EVALUATE SUBCOMMAND ----
32
+ evaluate_parser = subparsers.add_parser(
33
+ "evaluate",
34
+ help="Evaluate model performance",
35
+ )
36
+ evaluate_parser.add_argument("--model_path", type=str, default=None)
37
+ evaluate_parser.add_argument("--config", type=str, required=True)
38
+ evaluate_parser.add_argument("--exp_name", type=str, default=exp_name)
39
+ evaluate_parser.add_argument("--run_name", type=str, default="run")
40
+ evaluate_parser.add_argument("--log_dir", type=str, default="mlruns")
41
+ evaluate_parser.add_argument("--data_path", type=str, default="data")
42
+
43
+ # ---- PREDICT SUBCOMMAND ----
44
+ predict_parser = subparsers.add_parser("predict", help="Evaluate model performance")
45
+ predict_parser.add_argument("--model_path", type=str, required=None)
46
+ predict_parser.add_argument("--config", type=str, required=True)
47
+ predict_parser.add_argument("--exp_name", type=str, default=exp_name)
48
+ predict_parser.add_argument("--run_name", type=str, default="run")
49
+ predict_parser.add_argument("--log_dir", type=str, default="mlruns")
50
+ predict_parser.add_argument("--data_path", type=str, default="data")
51
+ predict_parser.add_argument("--output_path", type=str, default="data")
52
+
53
+ args = parser.parse_args()
54
+ main_cli(args)
55
+
56
+
57
+ if __name__ == "__main__":
58
+ main()
@@ -0,0 +1,134 @@
1
+ from gridfm_graphkit.datasets.powergrid_datamodule import LitGridDataModule
2
+ from gridfm_graphkit.io.param_handler import NestedNamespace
3
+ from gridfm_graphkit.training.callbacks import SaveBestModelStateDict
4
+ import numpy as np
5
+ import os
6
+ import yaml
7
+ import torch
8
+ import random
9
+ import pandas as pd
10
+
11
+ from gridfm_graphkit.tasks.feature_reconstruction_task import FeatureReconstructionTask
12
+ from lightning.pytorch.callbacks.early_stopping import EarlyStopping
13
+ from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
14
+ from lightning.pytorch.loggers import MLFlowLogger
15
+ import lightning as L
16
+
17
+
18
+ def get_training_callbacks(args):
19
+ early_stop_callback = EarlyStopping(
20
+ monitor="Validation loss",
21
+ min_delta=args.callbacks.tol,
22
+ patience=args.callbacks.patience,
23
+ verbose=False,
24
+ mode="min",
25
+ )
26
+
27
+ save_best_model_callback = SaveBestModelStateDict(
28
+ monitor="Validation loss",
29
+ mode="min",
30
+ filename="best_model_state_dict.pt",
31
+ )
32
+
33
+ checkpoint_callback = ModelCheckpoint(
34
+ monitor="Validation loss", # or whichever metric you track
35
+ mode="min",
36
+ save_last=True,
37
+ save_top_k=0,
38
+ )
39
+
40
+ return [early_stop_callback, save_best_model_callback, checkpoint_callback]
41
+
42
+
43
+ def main_cli(args):
44
+ logger = MLFlowLogger(
45
+ save_dir=args.log_dir,
46
+ experiment_name=args.exp_name,
47
+ run_name=args.run_name,
48
+ )
49
+
50
+ with open(args.config, "r") as f:
51
+ base_config = yaml.safe_load(f)
52
+
53
+ config_args = NestedNamespace(**base_config)
54
+
55
+ torch.manual_seed(config_args.seed)
56
+ random.seed(config_args.seed)
57
+ np.random.seed(config_args.seed)
58
+
59
+ litGrid = LitGridDataModule(config_args, args.data_path)
60
+ model = FeatureReconstructionTask(
61
+ config_args,
62
+ litGrid.node_normalizers,
63
+ litGrid.edge_normalizers,
64
+ )
65
+ if args.command != "train":
66
+ print(f"Loading model weights from {args.model_path}")
67
+ state_dict = torch.load(args.model_path)
68
+ model.load_state_dict(state_dict)
69
+
70
+ trainer = L.Trainer(
71
+ logger=logger,
72
+ accelerator=config_args.training.accelerator,
73
+ devices=config_args.training.devices,
74
+ strategy=config_args.training.strategy,
75
+ log_every_n_steps=1,
76
+ default_root_dir=args.log_dir,
77
+ max_epochs=config_args.training.epochs,
78
+ callbacks=get_training_callbacks(config_args),
79
+ )
80
+ if args.command == "train" or args.command == "finetune":
81
+ trainer.fit(model=model, datamodule=litGrid)
82
+
83
+ if args.command != "predict":
84
+ trainer.test(model=model, datamodule=litGrid)
85
+
86
+ if args.command == "predict":
87
+ predictions = trainer.predict(model=model, datamodule=litGrid)
88
+ all_outputs = []
89
+ all_mask_PQ = []
90
+ all_mask_PV = []
91
+ all_mask_REF = []
92
+ all_scenarios = []
93
+ all_bus_numbers = []
94
+
95
+ for batch in predictions:
96
+ all_outputs.append(batch["output"])
97
+ all_mask_PQ.append(batch["mask_PQ"])
98
+ all_mask_PV.append(batch["mask_PV"])
99
+ all_mask_REF.append(batch["mask_REF"])
100
+ all_scenarios.append(batch["scenario_id"])
101
+ all_bus_numbers.append(batch["bus_number"])
102
+
103
+ # Concatenate all
104
+ outputs = np.concatenate(all_outputs, axis=0) # shape: [num_nodes, 6]
105
+ mask_PQ = np.concatenate(all_mask_PQ, axis=0)
106
+ mask_PV = np.concatenate(all_mask_PV, axis=0)
107
+ mask_REF = np.concatenate(all_mask_REF, axis=0)
108
+ scenario_ids = np.concatenate(all_scenarios, axis=0)
109
+ bus_numbers = np.concatenate(all_bus_numbers, axis=0)
110
+
111
+ # Build DataFrame
112
+ df = pd.DataFrame(
113
+ {
114
+ "scenario": scenario_ids,
115
+ "bus": bus_numbers,
116
+ "PD": outputs[:, 0],
117
+ "QD": outputs[:, 1],
118
+ "PG": outputs[:, 2],
119
+ "QG": outputs[:, 3],
120
+ "VM": outputs[:, 4],
121
+ "VA": outputs[:, 5],
122
+ "PQ": mask_PQ.astype(int),
123
+ "PV": mask_PV.astype(int),
124
+ "REF": mask_REF.astype(int),
125
+ },
126
+ )
127
+
128
+ # Save CSV
129
+ output_dir = os.path.join(args.output_path)
130
+ os.makedirs(output_dir, exist_ok=True)
131
+ csv_path = os.path.join(output_dir, "predictions.csv")
132
+ df.to_csv(csv_path, index=False)
133
+
134
+ print(f"Saved predictions to {csv_path}")
@@ -0,0 +1,23 @@
1
+ from gridfm_graphkit.datasets.transforms import (
2
+ AddPFMask,
3
+ AddIdentityMask,
4
+ AddRandomMask,
5
+ AddOPFMask,
6
+ )
7
+ from gridfm_graphkit.datasets.normalizers import (
8
+ Standardizer,
9
+ MinMaxNormalizer,
10
+ BaseMVANormalizer,
11
+ IdentityNormalizer,
12
+ )
13
+
14
+ __all__ = [
15
+ "AddPFMask",
16
+ "AddIdentityMask",
17
+ "AddRandomMask",
18
+ "AddOPFMask",
19
+ "Standardizer",
20
+ "MinMaxNormalizer",
21
+ "BaseMVANormalizer",
22
+ "IdentityNormalizer",
23
+ ]
@@ -1,4 +1,5 @@
1
1
  from gridfm_graphkit.datasets.globals import PD, QD, PG, QG, VA
2
+ from gridfm_graphkit.io.registries import NORMALIZERS_REGISTRY
2
3
  import torch
3
4
  from abc import ABC, abstractmethod
4
5
 
@@ -53,13 +54,25 @@ class Normalizer(ABC):
53
54
  Original tensor.
54
55
  """
55
56
 
57
+ @abstractmethod
58
+ def get_stats(self) -> dict:
59
+ """
60
+ Return the stored normalization statistics for logging/inspection.
61
+ """
62
+
56
63
 
64
+ @NORMALIZERS_REGISTRY.register("minmax")
57
65
  class MinMaxNormalizer(Normalizer):
58
66
  """
59
67
  Scales each feature to the [0, 1] range.
68
+
69
+ Args:
70
+ node_data (bool): Whether data is node-level or edge-level
71
+ args (NestedNamespace): Parameters
72
+
60
73
  """
61
74
 
62
- def __init__(self):
75
+ def __init__(self, node_data: bool, args):
63
76
  self.min_val = None
64
77
  self.max_val = None
65
78
 
@@ -95,13 +108,25 @@ class MinMaxNormalizer(Normalizer):
95
108
  diff[diff == 0] = 1
96
109
  return (normalized_data * diff) + self.min_val
97
110
 
111
+ def get_stats(self) -> dict:
112
+ return {
113
+ "min_value": self.min_val.tolist() if self.min_val is not None else None,
114
+ "max_value": self.max_val.tolist() if self.max_val is not None else None,
115
+ }
98
116
 
117
+
118
+ @NORMALIZERS_REGISTRY.register("standard")
99
119
  class Standardizer(Normalizer):
100
120
  """
101
121
  Standardizes each feature to zero mean and unit variance.
122
+
123
+ Args:
124
+ node_data (bool): Whether data is node-level or edge-level
125
+ args (NestedNamespace): Parameters
126
+
102
127
  """
103
128
 
104
- def __init__(self):
129
+ def __init__(self, node_data: bool, args):
105
130
  self.mean = None
106
131
  self.std = None
107
132
 
@@ -137,7 +162,14 @@ class Standardizer(Normalizer):
137
162
  std[std == 0] = 1
138
163
  return (normalized_data * std) + self.mean
139
164
 
165
+ def get_stats(self) -> dict:
166
+ return {
167
+ "mean": self.mean.tolist() if self.mean is not None else None,
168
+ "std": self.std.tolist() if self.std is not None else None,
169
+ }
170
+
140
171
 
172
+ @NORMALIZERS_REGISTRY.register("baseMVAnorm")
141
173
  class BaseMVANormalizer(Normalizer):
142
174
  """
143
175
  In power systems, a suitable normalization strategy must preserve the physical properties of
@@ -148,14 +180,17 @@ class BaseMVANormalizer(Normalizer):
148
180
  preserving fundamental physical relationships.
149
181
  """
150
182
 
151
- def __init__(self, node_data: bool, baseMVA_orig: float = 100.0):
183
+ def __init__(self, node_data: bool, args):
152
184
  """
153
185
  Args:
154
- node_data: Whether data is node-level or edge-level (PD, QD, PG, QG, VA).
155
- baseMVA_orig: Original baseMVA (e.g. from MATPOWER).
186
+ node_data: Whether data is node-level or edge-level
187
+ args (NestedNamespace): Parameters
188
+
189
+ Attributes:
190
+ baseMVA (float): baseMVA found in casefile. From ``args.data.baseMVA``.
156
191
  """
157
192
  self.node_data = node_data
158
- self.baseMVA_orig = baseMVA_orig
193
+ self.baseMVA_orig = getattr(args.data, "baseMVA", 100)
159
194
  self.baseMVA = None
160
195
 
161
196
  def to(self, device):
@@ -208,12 +243,26 @@ class BaseMVANormalizer(Normalizer):
208
243
 
209
244
  return normalized_data
210
245
 
246
+ def get_stats(self) -> dict:
247
+ return {
248
+ "baseMVA": self.baseMVA,
249
+ "baseMVA_orig": self.baseMVA_orig,
250
+ }
251
+
211
252
 
253
+ @NORMALIZERS_REGISTRY.register("identity")
212
254
  class IdentityNormalizer(Normalizer):
213
255
  """
214
256
  No normalization: returns data unchanged.
257
+
258
+ Args:
259
+ node_data: Whether data is node-level or edge-level
260
+ args (NestedNamespace): Parameters
215
261
  """
216
262
 
263
+ def __init__(self, node_data: bool, args):
264
+ pass
265
+
217
266
  def fit(self, data: torch.Tensor) -> dict:
218
267
  return {}
219
268
 
@@ -225,3 +274,6 @@ class IdentityNormalizer(Normalizer):
225
274
 
226
275
  def inverse_transform(self, normalized_data: torch.Tensor) -> torch.Tensor:
227
276
  return normalized_data
277
+
278
+ def get_stats(self) -> dict:
279
+ return {"note": "No normalization applied."}