gridfm-graphkit 0.0.3__tar.gz → 0.0.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54) hide show
  1. {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.4}/PKG-INFO +51 -44
  2. gridfm_graphkit-0.0.4/README.md +141 -0
  3. gridfm_graphkit-0.0.4/gridfm_graphkit/__init__.py +6 -0
  4. gridfm_graphkit-0.0.4/gridfm_graphkit/__main__.py +58 -0
  5. gridfm_graphkit-0.0.4/gridfm_graphkit/cli.py +134 -0
  6. gridfm_graphkit-0.0.4/gridfm_graphkit/datasets/__init__.py +23 -0
  7. gridfm_graphkit-0.0.3/gridfm_graphkit/datasets/data_normalization.py → gridfm_graphkit-0.0.4/gridfm_graphkit/datasets/normalizers.py +58 -6
  8. gridfm_graphkit-0.0.4/gridfm_graphkit/datasets/powergrid_datamodule.py +207 -0
  9. gridfm_graphkit-0.0.3/gridfm_graphkit/datasets/powergrid.py → gridfm_graphkit-0.0.4/gridfm_graphkit/datasets/powergrid_dataset.py +57 -22
  10. {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.4}/gridfm_graphkit/datasets/transforms.py +17 -3
  11. {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.4}/gridfm_graphkit/datasets/utils.py +0 -16
  12. gridfm_graphkit-0.0.4/gridfm_graphkit/io/param_handler.py +138 -0
  13. gridfm_graphkit-0.0.4/gridfm_graphkit/io/registries.py +42 -0
  14. gridfm_graphkit-0.0.4/gridfm_graphkit/models/__init__.py +4 -0
  15. gridfm_graphkit-0.0.3/gridfm_graphkit/models/graphTransformer.py → gridfm_graphkit-0.0.4/gridfm_graphkit/models/gnn_transformer.py +34 -34
  16. {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.4}/gridfm_graphkit/models/gps_transformer.py +40 -43
  17. gridfm_graphkit-0.0.4/gridfm_graphkit/tasks/feature_reconstruction_task.py +366 -0
  18. gridfm_graphkit-0.0.4/gridfm_graphkit/training/callbacks.py +49 -0
  19. {gridfm_graphkit-0.0.3/gridfm_graphkit/utils → gridfm_graphkit-0.0.4/gridfm_graphkit/training}/loss.py +9 -9
  20. gridfm_graphkit-0.0.4/gridfm_graphkit/utils/visualization.py +99 -0
  21. {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.4}/gridfm_graphkit.egg-info/PKG-INFO +51 -44
  22. {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.4}/gridfm_graphkit.egg-info/SOURCES.txt +12 -9
  23. {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.4}/gridfm_graphkit.egg-info/requires.txt +1 -0
  24. {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.4}/pyproject.toml +2 -1
  25. gridfm_graphkit-0.0.4/tests/test_data_module.py +60 -0
  26. gridfm_graphkit-0.0.4/tests/test_full_pipeline.py +82 -0
  27. gridfm_graphkit-0.0.4/tests/test_losses.py +36 -0
  28. gridfm_graphkit-0.0.4/tests/test_model_outputs.py +69 -0
  29. gridfm_graphkit-0.0.4/tests/test_normalization.py +36 -0
  30. {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.4}/tests/test_yaml_configs.py +23 -0
  31. gridfm_graphkit-0.0.3/README.md +0 -135
  32. gridfm_graphkit-0.0.3/gridfm_graphkit/__main__.py +0 -62
  33. gridfm_graphkit-0.0.3/gridfm_graphkit/cli.py +0 -530
  34. gridfm_graphkit-0.0.3/gridfm_graphkit/evaluation/node_level.py +0 -334
  35. gridfm_graphkit-0.0.3/gridfm_graphkit/io/param_handler.py +0 -293
  36. gridfm_graphkit-0.0.3/gridfm_graphkit/models/__init__.py +0 -0
  37. gridfm_graphkit-0.0.3/gridfm_graphkit/training/__init__.py +0 -0
  38. gridfm_graphkit-0.0.3/gridfm_graphkit/training/callbacks.py +0 -47
  39. gridfm_graphkit-0.0.3/gridfm_graphkit/training/plugins.py +0 -218
  40. gridfm_graphkit-0.0.3/gridfm_graphkit/training/trainer.py +0 -156
  41. gridfm_graphkit-0.0.3/gridfm_graphkit/utils/__init__.py +0 -0
  42. gridfm_graphkit-0.0.3/gridfm_graphkit/utils/visualization.py +0 -324
  43. gridfm_graphkit-0.0.3/tests/test_model_outputs.py +0 -55
  44. gridfm_graphkit-0.0.3/tests/test_training.py +0 -90
  45. {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.4}/LICENSE +0 -0
  46. {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.4}/gridfm_graphkit/datasets/globals.py +0 -0
  47. {gridfm_graphkit-0.0.3/gridfm_graphkit → gridfm_graphkit-0.0.4/gridfm_graphkit/io}/__init__.py +0 -0
  48. {gridfm_graphkit-0.0.3/gridfm_graphkit/datasets → gridfm_graphkit-0.0.4/gridfm_graphkit/tasks}/__init__.py +0 -0
  49. {gridfm_graphkit-0.0.3/gridfm_graphkit/evaluation → gridfm_graphkit-0.0.4/gridfm_graphkit/training}/__init__.py +0 -0
  50. {gridfm_graphkit-0.0.3/gridfm_graphkit/io → gridfm_graphkit-0.0.4/gridfm_graphkit/utils}/__init__.py +0 -0
  51. {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.4}/gridfm_graphkit.egg-info/dependency_links.txt +0 -0
  52. {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.4}/gridfm_graphkit.egg-info/entry_points.txt +0 -0
  53. {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.4}/gridfm_graphkit.egg-info/top_level.txt +0 -0
  54. {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.4}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gridfm-graphkit
3
- Version: 0.0.3
3
+ Version: 0.0.4
4
4
  Summary: Grid Foundation Model
5
5
  Author-email: Matteo Mazzonelli <matteo.mazzonelli1@ibm.com>, Alban Puech <apuech@seas.harvard.edu>, Tamara Govindasamy <tamara.govindasamy@ibm.com>, Mangaliso Mngomezulu <mngomezulum@ibm.com>, Etienne Vos <etienne.vos@ibm.com>, Celia Cintas <celia.cintas@ibm.com>, Jonas Weiss <jwe@zurich.ibm.com>
6
6
  Maintainer-email: Matteo Mazzonelli <matteo.mazzonelli1@ibm.com>
@@ -25,6 +25,7 @@ Requires-Dist: torch>=2.7.1
25
25
  Requires-Dist: torch-geometric>=2.6.1
26
26
  Requires-Dist: torchaudio>=2.7.1
27
27
  Requires-Dist: torchvision>=0.22.1
28
+ Requires-Dist: lightning
28
29
  Provides-Extra: dev
29
30
  Requires-Dist: mkdocs-material; extra == "dev"
30
31
  Requires-Dist: mkdocstrings[python]; extra == "dev"
@@ -38,6 +39,9 @@ Dynamic: license-file
38
39
 
39
40
  # gridfm-graphkit
40
41
  [![Docs](https://img.shields.io/badge/docs-available-brightgreen)](https://gridfm.github.io/gridfm-graphkit/)
42
+ ![Coverage](https://img.shields.io/badge/coverage-83%25-yellowgreen)
43
+ ![Python](https://img.shields.io/badge/python-3.10%20%E2%80%93%203.12-blue)
44
+ ![License](https://img.shields.io/badge/license-Apache%202.0-blue)
41
45
 
42
46
  This library is brought to you by the GridFM team to train, finetune and interact with a foundation model for the electric power grid.
43
47
 
@@ -73,7 +77,7 @@ pip install -e .[dev,test]
73
77
  ```
74
78
 
75
79
 
76
- # gridfm-graphkit CLI
80
+ # CLI commands
77
81
 
78
82
  An interface to train, fine-tune, and evaluate GridFM models using configurable YAML files and MLflow tracking.
79
83
 
@@ -83,9 +87,10 @@ gridfm_graphkit <command> [OPTIONS]
83
87
 
84
88
  Available commands:
85
89
 
86
- * `train` – Train a new model
87
- * `predict` – Evaluate an existing model
88
- * `finetune` – Fine-tune a pre-trained model
90
+ * `train` – Train a new model from scratch
91
+ * `finetune` – Fine-tune an existing pre-trained model
92
+ * `evaluate` – Evaluate model performance on a dataset
93
+ * `predict` – Run inference and save predictions
89
94
 
90
95
  ---
91
96
 
@@ -99,75 +104,77 @@ gridfm_graphkit train --config path/to/config.yaml
99
104
 
100
105
  | Argument | Type | Description | Default |
101
106
  | ---------------- | ------ | ---------------------------------------------------------------- | ------- |
102
- | `--config` | `str` | **Required for standard training**. Path to base config YAML. | `None` |
103
- | `--grid` | `str` | **Optional**. Path to grid search YAML. Not supported with `-c`. | `None` |
104
- | `--exp` | `str` | **Optional**. MLflow experiment name. Defaults to a timestamp. | `None` |
107
+ | `--config` | `str` | **Required**. Path to the training configuration YAML file. | `None` |
108
+ | `--exp_name` | `str` | **Optional**. MLflow experiment name. | `timestamp` |
109
+ | `--run_name` | `str` | **Optional**. MLflow run name. | `run` |
110
+ | `--log_dir` | `str` | **Optional**. MLflow logging directory. | `mlruns` |
105
111
  | `--data_path` | `str` | **Optional**. Root dataset directory. | `data` |
106
- | `-c` | `flag` | **Optional**. Enable checkpoint mode. | `False` |
107
- | `--model_exp_id` | `str` | **Required if `-c` is used**. MLflow experiment ID. | `None` |
108
- | `--model_run_id` | `str` | **Required if `-c` is used**. MLflow run ID. | `None` |
109
112
 
110
113
  ### Examples
111
114
 
112
115
  **Standard Training:**
113
116
 
114
117
  ```bash
115
- gridfm_graphkit train --config config/train.yaml --exp "run1"
118
+ gridfm_graphkit train --config examples/config/case30_ieee_base.yaml --data_path examples/data
116
119
  ```
117
120
 
118
- **Grid Search Training:**
121
+ ---
122
+
123
+ ## Fine-Tuning Models
119
124
 
120
125
  ```bash
121
- gridfm_graphkit train --config config/train.yaml --grid config/grid.yaml
126
+ gridfm_graphkit finetune --config path/to/config.yaml --model_path path/to/model.pth
122
127
  ```
123
128
 
124
- **Training from Checkpoint:**
129
+ ### Arguments
130
+
131
+ | Argument | Type | Description | Default |
132
+ | -------------- | ----- | ----------------------------------------------- | --------- |
133
+ | `--config` | `str` | **Required**. Fine-tuning configuration file. | `None` |
134
+ | `--model_path` | `str` | **Required**. Path to a pre-trained model file. | `None` |
135
+ | `--exp_name` | `str` | MLflow experiment name. | timestamp |
136
+ | `--run_name` | `str` | MLflow run name. | `run` |
137
+ | `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
138
+ | `--data_path` | `str` | Root dataset directory. | `data` |
125
139
 
126
- ```bash
127
- gridfm_graphkit train -c --model_exp_id 123 --model_run_id abc
128
- ```
129
140
 
130
141
  ---
131
142
 
132
143
  ## Evaluating Models
133
144
 
134
145
  ```bash
135
- gridfm_graphkit predict --model_path model.pth --config config/eval.yaml --eval_name run_eval
146
+ gridfm_graphkit evaluate --config path/to/eval.yaml --model_path path/to/model.pth
136
147
  ```
137
148
 
138
149
  ### Arguments
139
150
 
140
- | Argument | Type | Description | Default |
141
- | ---------------- | ----- | ----------------------------------------------------------------- | ------------ |
142
- | `--model_path` | `str` | **Optional**. Path to a model file. | `None` |
143
- | `--model_exp_id` | `str` | **Required if `--model_path` is not used**. MLflow experiment ID. | `None` |
144
- | `--model_run_id` | `str` | **Required if `--model_path` is not used**. MLflow run ID. | `None` |
145
- | `--model_name` | `str` | **Optional**. Filename inside MLflow artifacts. | `best_model` |
146
- | `--config` | `str` | **Required**. Path to evaluation config. | `None` |
147
- | `--eval_name` | `str` | **Required**. Name of the evaluation run in MLflow. | `None` |
148
- | `--data_path` | `str` | **Optional**. Path to dataset directory. | `data` |
149
-
150
- ### Examples
151
-
152
- **Evaluate a Logged MLflow Model:**
153
-
154
- ```bash
155
- gridfm_graphkit predict --config config/eval.yaml --eval_name run_eval --model_exp_id 1 --model_run_id abc
156
- ```
151
+ | Argument | Type | Description | Default |
152
+ | -------------- | ----- | ---------------------------------------- | --------- |
153
+ | `--config` | `str` | **Required**. Path to evaluation config. | `None` |
154
+ | `--model_path` | `str` | Path to the trained model file. | `None` |
155
+ | `--exp_name` | `str` | MLflow experiment name. | timestamp |
156
+ | `--run_name` | `str` | MLflow run name. | `run` |
157
+ | `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
158
+ | `--data_path` | `str` | Dataset directory. | `data` |
157
159
 
158
160
  ---
159
161
 
160
- ## Fine-Tuning Models
162
+ ## Running Predictions
161
163
 
162
164
  ```bash
163
- gridfm_graphkit finetune --config path/to/config.yaml --model_path path/to/model.pth
165
+ gridfm_graphkit predict --config path/to/config.yaml --model_path path/to/model.pth
164
166
  ```
165
167
 
166
168
  ### Arguments
167
169
 
168
- | Argument | Type | Description | Default |
169
- | -------------- | ----- | ----------------------------------------------- | ------- |
170
- | `--config` | `str` | **Required**. Fine-tuning configuration file. | `None` |
171
- | `--model_path` | `str` | **Required**. Path to a pre-trained model file. | `None` |
172
- | `--exp` | `str` | **Optional**. MLflow experiment name. | `None` |
173
- | `--data_path` | `str` | **Optional**. Root dataset directory. | `data` |
170
+ | Argument | Type | Description | Default |
171
+ | --------------- | ----- | --------------------------------------------- | --------- |
172
+ | `--config` | `str` | **Required**. Path to prediction config file. | `None` |
173
+ | `--model_path` | `str` | Path to the trained model file. | `None` |
174
+ | `--exp_name` | `str` | MLflow experiment name. | timestamp |
175
+ | `--run_name` | `str` | MLflow run name. | `run` |
176
+ | `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
177
+ | `--data_path` | `str` | Dataset directory. | `data` |
178
+ | `--output_path` | `str` | Directory where predictions are saved. | `data` |
179
+
180
+ ---
@@ -0,0 +1,141 @@
1
+ # gridfm-graphkit
2
+ [![Docs](https://img.shields.io/badge/docs-available-brightgreen)](https://gridfm.github.io/gridfm-graphkit/)
3
+ ![Coverage](https://img.shields.io/badge/coverage-83%25-yellowgreen)
4
+ ![Python](https://img.shields.io/badge/python-3.10%20%E2%80%93%203.12-blue)
5
+ ![License](https://img.shields.io/badge/license-Apache%202.0-blue)
6
+
7
+ This library is brought to you by the GridFM team to train, finetune and interact with a foundation model for the electric power grid.
8
+
9
+ ---
10
+
11
+ <p align="center">
12
+ <img src="https://raw.githubusercontent.com/gridfm/gridfm-graphkit/refs/heads/main/docs/figs/pre_training.png" alt="GridFM logo"/>
13
+ <br/>
14
+ </p>
15
+
16
+ # Installation
17
+
18
+ You can install `gridfm-graphkit` directly from PyPI:
19
+
20
+ ```bash
21
+ pip install gridfm-graphkit
22
+ ```
23
+
24
+ To contribute or develop locally, clone the repository and install in editable mode:
25
+
26
+ ```bash
27
+ git clone git@github.com:gridfm/gridfm-graphkit.git
28
+ cd gridfm-graphkit
29
+ python -m venv venv
30
+ source venv/bin/activate
31
+ pip install -e .
32
+ ```
33
+
34
+ For documentation generation and unit testing, install with the optional `dev` and `test` extras:
35
+
36
+ ```bash
37
+ pip install -e .[dev,test]
38
+ ```
39
+
40
+
41
+ # CLI commands
42
+
43
+ An interface to train, fine-tune, and evaluate GridFM models using configurable YAML files and MLflow tracking.
44
+
45
+ ```bash
46
+ gridfm_graphkit <command> [OPTIONS]
47
+ ```
48
+
49
+ Available commands:
50
+
51
+ * `train` – Train a new model from scratch
52
+ * `finetune` – Fine-tune an existing pre-trained model
53
+ * `evaluate` – Evaluate model performance on a dataset
54
+ * `predict` – Run inference and save predictions
55
+
56
+ ---
57
+
58
+ ## Training Models
59
+
60
+ ```bash
61
+ gridfm_graphkit train --config path/to/config.yaml
62
+ ```
63
+
64
+ ### Arguments
65
+
66
+ | Argument | Type | Description | Default |
67
+ | ---------------- | ------ | ---------------------------------------------------------------- | ------- |
68
+ | `--config` | `str` | **Required**. Path to the training configuration YAML file. | `None` |
69
+ | `--exp_name` | `str` | **Optional**. MLflow experiment name. | `timestamp` |
70
+ | `--run_name` | `str` | **Optional**. MLflow run name. | `run` |
71
+ | `--log_dir` | `str` | **Optional**. MLflow logging directory. | `mlruns` |
72
+ | `--data_path` | `str` | **Optional**. Root dataset directory. | `data` |
73
+
74
+ ### Examples
75
+
76
+ **Standard Training:**
77
+
78
+ ```bash
79
+ gridfm_graphkit train --config examples/config/case30_ieee_base.yaml --data_path examples/data
80
+ ```
81
+
82
+ ---
83
+
84
+ ## Fine-Tuning Models
85
+
86
+ ```bash
87
+ gridfm_graphkit finetune --config path/to/config.yaml --model_path path/to/model.pth
88
+ ```
89
+
90
+ ### Arguments
91
+
92
+ | Argument | Type | Description | Default |
93
+ | -------------- | ----- | ----------------------------------------------- | --------- |
94
+ | `--config` | `str` | **Required**. Fine-tuning configuration file. | `None` |
95
+ | `--model_path` | `str` | **Required**. Path to a pre-trained model file. | `None` |
96
+ | `--exp_name` | `str` | MLflow experiment name. | timestamp |
97
+ | `--run_name` | `str` | MLflow run name. | `run` |
98
+ | `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
99
+ | `--data_path` | `str` | Root dataset directory. | `data` |
100
+
101
+
102
+ ---
103
+
104
+ ## Evaluating Models
105
+
106
+ ```bash
107
+ gridfm_graphkit evaluate --config path/to/eval.yaml --model_path path/to/model.pth
108
+ ```
109
+
110
+ ### Arguments
111
+
112
+ | Argument | Type | Description | Default |
113
+ | -------------- | ----- | ---------------------------------------- | --------- |
114
+ | `--config` | `str` | **Required**. Path to evaluation config. | `None` |
115
+ | `--model_path` | `str` | Path to the trained model file. | `None` |
116
+ | `--exp_name` | `str` | MLflow experiment name. | timestamp |
117
+ | `--run_name` | `str` | MLflow run name. | `run` |
118
+ | `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
119
+ | `--data_path` | `str` | Dataset directory. | `data` |
120
+
121
+ ---
122
+
123
+ ## Running Predictions
124
+
125
+ ```bash
126
+ gridfm_graphkit predict --config path/to/config.yaml --model_path path/to/model.pth
127
+ ```
128
+
129
+ ### Arguments
130
+
131
+ | Argument | Type | Description | Default |
132
+ | --------------- | ----- | --------------------------------------------- | --------- |
133
+ | `--config` | `str` | **Required**. Path to prediction config file. | `None` |
134
+ | `--model_path` | `str` | Path to the trained model file. | `None` |
135
+ | `--exp_name` | `str` | MLflow experiment name. | timestamp |
136
+ | `--run_name` | `str` | MLflow run name. | `run` |
137
+ | `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
138
+ | `--data_path` | `str` | Dataset directory. | `data` |
139
+ | `--output_path` | `str` | Directory where predictions are saved. | `data` |
140
+
141
+ ---
@@ -0,0 +1,6 @@
1
+ import gridfm_graphkit.datasets
2
+ import gridfm_graphkit.models
3
+
4
+ __all__ = [
5
+ "gridfm_graphkit",
6
+ ]
@@ -0,0 +1,58 @@
1
+ import argparse
2
+ from datetime import datetime
3
+ from gridfm_graphkit.cli import main_cli
4
+
5
+
6
+ def main():
7
+ parser = argparse.ArgumentParser(
8
+ prog="gridfm_graphkit",
9
+ description="gridfm-graphkit CLI",
10
+ )
11
+ subparsers = parser.add_subparsers(dest="command", required=True)
12
+ exp_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
13
+
14
+ # ---- TRAIN SUBCOMMAND ----
15
+ train_parser = subparsers.add_parser("train", help="Run training")
16
+ train_parser.add_argument("--config", type=str, required=True)
17
+ train_parser.add_argument("--exp_name", type=str, default=exp_name)
18
+ train_parser.add_argument("--run_name", type=str, default="run")
19
+ train_parser.add_argument("--log_dir", type=str, default="mlruns")
20
+ train_parser.add_argument("--data_path", type=str, default="data")
21
+
22
+ # ---- FINETUNE SUBCOMMAND ----
23
+ finetune_parser = subparsers.add_parser("finetune", help="Run fine-tuning")
24
+ finetune_parser.add_argument("--config", type=str, required=True)
25
+ finetune_parser.add_argument("--model_path", type=str, required=True)
26
+ finetune_parser.add_argument("--exp_name", type=str, default=exp_name)
27
+ finetune_parser.add_argument("--run_name", type=str, default="run")
28
+ finetune_parser.add_argument("--log_dir", type=str, default="mlruns")
29
+ finetune_parser.add_argument("--data_path", type=str, default="data")
30
+
31
+ # ---- EVALUATE SUBCOMMAND ----
32
+ evaluate_parser = subparsers.add_parser(
33
+ "evaluate",
34
+ help="Evaluate model performance",
35
+ )
36
+ evaluate_parser.add_argument("--model_path", type=str, default=None)
37
+ evaluate_parser.add_argument("--config", type=str, required=True)
38
+ evaluate_parser.add_argument("--exp_name", type=str, default=exp_name)
39
+ evaluate_parser.add_argument("--run_name", type=str, default="run")
40
+ evaluate_parser.add_argument("--log_dir", type=str, default="mlruns")
41
+ evaluate_parser.add_argument("--data_path", type=str, default="data")
42
+
43
+ # ---- PREDICT SUBCOMMAND ----
44
+ predict_parser = subparsers.add_parser("predict", help="Evaluate model performance")
45
+ predict_parser.add_argument("--model_path", type=str, required=None)
46
+ predict_parser.add_argument("--config", type=str, required=True)
47
+ predict_parser.add_argument("--exp_name", type=str, default=exp_name)
48
+ predict_parser.add_argument("--run_name", type=str, default="run")
49
+ predict_parser.add_argument("--log_dir", type=str, default="mlruns")
50
+ predict_parser.add_argument("--data_path", type=str, default="data")
51
+ predict_parser.add_argument("--output_path", type=str, default="data")
52
+
53
+ args = parser.parse_args()
54
+ main_cli(args)
55
+
56
+
57
+ if __name__ == "__main__":
58
+ main()
@@ -0,0 +1,134 @@
1
+ from gridfm_graphkit.datasets.powergrid_datamodule import LitGridDataModule
2
+ from gridfm_graphkit.io.param_handler import NestedNamespace
3
+ from gridfm_graphkit.training.callbacks import SaveBestModelStateDict
4
+ import numpy as np
5
+ import os
6
+ import yaml
7
+ import torch
8
+ import random
9
+ import pandas as pd
10
+
11
+ from gridfm_graphkit.tasks.feature_reconstruction_task import FeatureReconstructionTask
12
+ from lightning.pytorch.callbacks.early_stopping import EarlyStopping
13
+ from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
14
+ from lightning.pytorch.loggers import MLFlowLogger
15
+ import lightning as L
16
+
17
+
18
def get_training_callbacks(args):
    """Build the Lightning callbacks used by every ``L.Trainer`` in this CLI.

    Args:
        args: Parsed config namespace. Only ``args.callbacks.tol`` and
            ``args.callbacks.patience`` are read.

    Returns:
        A list, in this order:
        1. an ``EarlyStopping`` callback on the validation loss;
        2. the project's ``SaveBestModelStateDict`` (presumably writes the best
           weights to ``best_model_state_dict.pt``; see its definition);
        3. a ``ModelCheckpoint`` with ``save_top_k=0`` and ``save_last=True``.
    """
    watched = "Validation loss"

    stopper = EarlyStopping(
        monitor=watched,
        min_delta=args.callbacks.tol,
        patience=args.callbacks.patience,
        verbose=False,
        mode="min",
    )
    best_weights = SaveBestModelStateDict(
        monitor=watched,
        mode="min",
        filename="best_model_state_dict.pt",
    )
    # save_top_k=0 disables "best-k" checkpoints; save_last keeps the latest one.
    last_ckpt = ModelCheckpoint(
        monitor=watched,
        mode="min",
        save_last=True,
        save_top_k=0,
    )
    return [stopper, best_weights, last_ckpt]
41
+
42
+
43
def _seed_everything(seed):
    """Seed PyTorch, Python's ``random`` and NumPy with the same ``seed``."""
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)


def _save_predictions(predictions, output_path):
    """Concatenate per-batch prediction dicts and write ``predictions.csv``.

    Args:
        predictions: List returned by ``trainer.predict``. Each element is read as
            a mapping with keys ``output``, ``mask_PQ``, ``mask_PV``, ``mask_REF``,
            ``scenario_id`` and ``bus_number``.
        output_path: Directory in which ``predictions.csv`` is written. It is
            created if missing.

    Returns:
        The path of the written CSV file.

    Raises:
        ValueError: If ``predictions`` is empty or ``None``.
    """
    if not predictions:
        raise ValueError("No predictions were returned; nothing to save.")

    def _concat(key):
        # Stack the per-batch arrays along the first (node) axis.
        return np.concatenate([batch[key] for batch in predictions], axis=0)

    outputs = _concat("output")  # expected shape: [num_nodes, 6]
    mask_PQ = _concat("mask_PQ")
    mask_PV = _concat("mask_PV")
    mask_REF = _concat("mask_REF")
    scenario_ids = _concat("scenario_id")
    bus_numbers = _concat("bus_number")

    df = pd.DataFrame(
        {
            "scenario": scenario_ids,
            "bus": bus_numbers,
            "PD": outputs[:, 0],
            "QD": outputs[:, 1],
            "PG": outputs[:, 2],
            "QG": outputs[:, 3],
            "VM": outputs[:, 4],
            "VA": outputs[:, 5],
            "PQ": mask_PQ.astype(int),
            "PV": mask_PV.astype(int),
            "REF": mask_REF.astype(int),
        },
    )

    os.makedirs(output_path, exist_ok=True)
    csv_path = os.path.join(output_path, "predictions.csv")
    df.to_csv(csv_path, index=False)
    return csv_path


def main_cli(args):
    """Run the train / finetune / evaluate / predict workflow selected by ``args``.

    Steps:
        - Create an MLflow logger.
        - Load the YAML config and seed all RNGs from ``config.seed``.
        - Build the data module and the ``FeatureReconstructionTask``.
        - For any command other than ``train``, load weights from ``--model_path``.
        - Run ``fit`` (train/finetune), then ``test`` (train/finetune/evaluate) or
          ``predict`` (predict, writing ``predictions.csv`` under ``--output_path``).

    Args:
        args: ``argparse.Namespace`` from ``gridfm_graphkit.__main__``. Reads
            ``command``, ``config``, ``exp_name``, ``run_name``, ``log_dir``,
            ``data_path``, and, depending on the command, ``model_path`` and
            ``output_path``.

    Raises:
        ValueError: If a command other than ``train`` is given without a model path.
    """
    command = args.command
    model_path = getattr(args, "model_path", None)
    # `evaluate`/`predict` declare --model_path as optional, but weights are always
    # loaded for them. Fail early with a clear message instead of letting
    # torch.load(None) raise an opaque error after setup work.
    if command != "train" and model_path is None:
        raise ValueError(f"--model_path is required for the '{command}' command.")

    logger = MLFlowLogger(
        save_dir=args.log_dir,
        experiment_name=args.exp_name,
        run_name=args.run_name,
    )

    with open(args.config, "r", encoding="utf-8") as f:
        base_config = yaml.safe_load(f)

    config_args = NestedNamespace(**base_config)
    _seed_everything(config_args.seed)

    lit_grid = LitGridDataModule(config_args, args.data_path)
    model = FeatureReconstructionTask(
        config_args,
        lit_grid.node_normalizers,
        lit_grid.edge_normalizers,
    )
    if command != "train":
        print(f"Loading model weights from {model_path}")
        # map_location="cpu" lets weights saved on a GPU load on CPU-only hosts;
        # device placement is left to the Lightning Trainer.
        state_dict = torch.load(model_path, map_location="cpu")
        model.load_state_dict(state_dict)

    trainer = L.Trainer(
        logger=logger,
        accelerator=config_args.training.accelerator,
        devices=config_args.training.devices,
        strategy=config_args.training.strategy,
        log_every_n_steps=1,
        default_root_dir=args.log_dir,
        max_epochs=config_args.training.epochs,
        callbacks=get_training_callbacks(config_args),
    )
    if command in ("train", "finetune"):
        trainer.fit(model=model, datamodule=lit_grid)

    if command != "predict":
        trainer.test(model=model, datamodule=lit_grid)
    else:
        predictions = trainer.predict(model=model, datamodule=lit_grid)
        csv_path = _save_predictions(predictions, args.output_path)
        print(f"Saved predictions to {csv_path}")
@@ -0,0 +1,23 @@
1
+ from gridfm_graphkit.datasets.transforms import (
2
+ AddPFMask,
3
+ AddIdentityMask,
4
+ AddRandomMask,
5
+ AddOPFMask,
6
+ )
7
+ from gridfm_graphkit.datasets.normalizers import (
8
+ Standardizer,
9
+ MinMaxNormalizer,
10
+ BaseMVANormalizer,
11
+ IdentityNormalizer,
12
+ )
13
+
14
+ __all__ = [
15
+ "AddPFMask",
16
+ "AddIdentityMask",
17
+ "AddRandomMask",
18
+ "AddOPFMask",
19
+ "Standardizer",
20
+ "MinMaxNormalizer",
21
+ "BaseMVANormalizer",
22
+ "IdentityNormalizer",
23
+ ]
@@ -1,4 +1,5 @@
1
1
  from gridfm_graphkit.datasets.globals import PD, QD, PG, QG, VA
2
+ from gridfm_graphkit.io.registries import NORMALIZERS_REGISTRY
2
3
  import torch
3
4
  from abc import ABC, abstractmethod
4
5
 
@@ -53,13 +54,25 @@ class Normalizer(ABC):
53
54
  Original tensor.
54
55
  """
55
56
 
57
+ @abstractmethod
58
+ def get_stats(self) -> dict:
59
+ """
60
+ Return the stored normalization statistics for logging/inspection.
61
+ """
62
+
56
63
 
64
+ @NORMALIZERS_REGISTRY.register("minmax")
57
65
  class MinMaxNormalizer(Normalizer):
58
66
  """
59
67
  Scales each feature to the [0, 1] range.
68
+
69
+ Args:
70
+ node_data (bool): Whether data is node-level or edge-level
71
+ args (NestedNamespace): Parameters
72
+
60
73
  """
61
74
 
62
- def __init__(self):
75
+ def __init__(self, node_data: bool, args):
63
76
  self.min_val = None
64
77
  self.max_val = None
65
78
 
@@ -95,13 +108,25 @@ class MinMaxNormalizer(Normalizer):
95
108
  diff[diff == 0] = 1
96
109
  return (normalized_data * diff) + self.min_val
97
110
 
111
+ def get_stats(self) -> dict:
112
+ return {
113
+ "min_value": self.min_val.tolist() if self.min_val is not None else None,
114
+ "max_value": self.max_val.tolist() if self.max_val is not None else None,
115
+ }
98
116
 
117
+
118
+ @NORMALIZERS_REGISTRY.register("standard")
99
119
  class Standardizer(Normalizer):
100
120
  """
101
121
  Standardizes each feature to zero mean and unit variance.
122
+
123
+ Args:
124
+ node_data (bool): Whether data is node-level or edge-level
125
+ args (NestedNamespace): Parameters
126
+
102
127
  """
103
128
 
104
- def __init__(self):
129
+ def __init__(self, node_data: bool, args):
105
130
  self.mean = None
106
131
  self.std = None
107
132
 
@@ -137,7 +162,14 @@ class Standardizer(Normalizer):
137
162
  std[std == 0] = 1
138
163
  return (normalized_data * std) + self.mean
139
164
 
165
+ def get_stats(self) -> dict:
166
+ return {
167
+ "mean": self.mean.tolist() if self.mean is not None else None,
168
+ "std": self.std.tolist() if self.std is not None else None,
169
+ }
170
+
140
171
 
172
+ @NORMALIZERS_REGISTRY.register("baseMVAnorm")
141
173
  class BaseMVANormalizer(Normalizer):
142
174
  """
143
175
  In power systems, a suitable normalization strategy must preserve the physical properties of
@@ -148,14 +180,17 @@ class BaseMVANormalizer(Normalizer):
148
180
  preserving fundamental physical relationships.
149
181
  """
150
182
 
151
- def __init__(self, node_data: bool, baseMVA_orig: float = 100.0):
183
+ def __init__(self, node_data: bool, args):
152
184
  """
153
185
  Args:
154
- node_data: Whether data is node-level or edge-level (PD, QD, PG, QG, VA).
155
- baseMVA_orig: Original baseMVA (e.g. from MATPOWER).
186
+ node_data: Whether data is node-level or edge-level
187
+ args (NestedNamespace): Parameters
188
+
189
+ Attributes:
190
+ baseMVA (float): baseMVA found in casefile. From ``args.data.baseMVA``.
156
191
  """
157
192
  self.node_data = node_data
158
- self.baseMVA_orig = baseMVA_orig
193
+ self.baseMVA_orig = getattr(args.data, "baseMVA", 100)
159
194
  self.baseMVA = None
160
195
 
161
196
  def to(self, device):
@@ -208,12 +243,26 @@ class BaseMVANormalizer(Normalizer):
208
243
 
209
244
  return normalized_data
210
245
 
246
+ def get_stats(self) -> dict:
247
+ return {
248
+ "baseMVA": self.baseMVA,
249
+ "baseMVA_orig": self.baseMVA_orig,
250
+ }
251
+
211
252
 
253
+ @NORMALIZERS_REGISTRY.register("identity")
212
254
  class IdentityNormalizer(Normalizer):
213
255
  """
214
256
  No normalization: returns data unchanged.
257
+
258
+ Args:
259
+ node_data: Whether data is node-level or edge-level
260
+ args (NestedNamespace): Parameters
215
261
  """
216
262
 
263
+ def __init__(self, node_data: bool, args):
264
+ pass
265
+
217
266
  def fit(self, data: torch.Tensor) -> dict:
218
267
  return {}
219
268
 
@@ -225,3 +274,6 @@ class IdentityNormalizer(Normalizer):
225
274
 
226
275
  def inverse_transform(self, normalized_data: torch.Tensor) -> torch.Tensor:
227
276
  return normalized_data
277
+
278
+ def get_stats(self) -> dict:
279
+ return {"note": "No normalization applied."}