gridfm-graphkit 0.0.3__tar.gz → 0.0.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.4}/PKG-INFO +51 -44
- gridfm_graphkit-0.0.4/README.md +141 -0
- gridfm_graphkit-0.0.4/gridfm_graphkit/__init__.py +6 -0
- gridfm_graphkit-0.0.4/gridfm_graphkit/__main__.py +58 -0
- gridfm_graphkit-0.0.4/gridfm_graphkit/cli.py +134 -0
- gridfm_graphkit-0.0.4/gridfm_graphkit/datasets/__init__.py +23 -0
- gridfm_graphkit-0.0.3/gridfm_graphkit/datasets/data_normalization.py → gridfm_graphkit-0.0.4/gridfm_graphkit/datasets/normalizers.py +58 -6
- gridfm_graphkit-0.0.4/gridfm_graphkit/datasets/powergrid_datamodule.py +207 -0
- gridfm_graphkit-0.0.3/gridfm_graphkit/datasets/powergrid.py → gridfm_graphkit-0.0.4/gridfm_graphkit/datasets/powergrid_dataset.py +57 -22
- {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.4}/gridfm_graphkit/datasets/transforms.py +17 -3
- {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.4}/gridfm_graphkit/datasets/utils.py +0 -16
- gridfm_graphkit-0.0.4/gridfm_graphkit/io/param_handler.py +138 -0
- gridfm_graphkit-0.0.4/gridfm_graphkit/io/registries.py +42 -0
- gridfm_graphkit-0.0.4/gridfm_graphkit/models/__init__.py +4 -0
- gridfm_graphkit-0.0.3/gridfm_graphkit/models/graphTransformer.py → gridfm_graphkit-0.0.4/gridfm_graphkit/models/gnn_transformer.py +34 -34
- {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.4}/gridfm_graphkit/models/gps_transformer.py +40 -43
- gridfm_graphkit-0.0.4/gridfm_graphkit/tasks/feature_reconstruction_task.py +366 -0
- gridfm_graphkit-0.0.4/gridfm_graphkit/training/callbacks.py +49 -0
- {gridfm_graphkit-0.0.3/gridfm_graphkit/utils → gridfm_graphkit-0.0.4/gridfm_graphkit/training}/loss.py +9 -9
- gridfm_graphkit-0.0.4/gridfm_graphkit/utils/visualization.py +99 -0
- {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.4}/gridfm_graphkit.egg-info/PKG-INFO +51 -44
- {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.4}/gridfm_graphkit.egg-info/SOURCES.txt +12 -9
- {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.4}/gridfm_graphkit.egg-info/requires.txt +1 -0
- {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.4}/pyproject.toml +2 -1
- gridfm_graphkit-0.0.4/tests/test_data_module.py +60 -0
- gridfm_graphkit-0.0.4/tests/test_full_pipeline.py +82 -0
- gridfm_graphkit-0.0.4/tests/test_losses.py +36 -0
- gridfm_graphkit-0.0.4/tests/test_model_outputs.py +69 -0
- gridfm_graphkit-0.0.4/tests/test_normalization.py +36 -0
- {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.4}/tests/test_yaml_configs.py +23 -0
- gridfm_graphkit-0.0.3/README.md +0 -135
- gridfm_graphkit-0.0.3/gridfm_graphkit/__main__.py +0 -62
- gridfm_graphkit-0.0.3/gridfm_graphkit/cli.py +0 -530
- gridfm_graphkit-0.0.3/gridfm_graphkit/evaluation/node_level.py +0 -334
- gridfm_graphkit-0.0.3/gridfm_graphkit/io/param_handler.py +0 -293
- gridfm_graphkit-0.0.3/gridfm_graphkit/models/__init__.py +0 -0
- gridfm_graphkit-0.0.3/gridfm_graphkit/training/__init__.py +0 -0
- gridfm_graphkit-0.0.3/gridfm_graphkit/training/callbacks.py +0 -47
- gridfm_graphkit-0.0.3/gridfm_graphkit/training/plugins.py +0 -218
- gridfm_graphkit-0.0.3/gridfm_graphkit/training/trainer.py +0 -156
- gridfm_graphkit-0.0.3/gridfm_graphkit/utils/__init__.py +0 -0
- gridfm_graphkit-0.0.3/gridfm_graphkit/utils/visualization.py +0 -324
- gridfm_graphkit-0.0.3/tests/test_model_outputs.py +0 -55
- gridfm_graphkit-0.0.3/tests/test_training.py +0 -90
- {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.4}/LICENSE +0 -0
- {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.4}/gridfm_graphkit/datasets/globals.py +0 -0
- {gridfm_graphkit-0.0.3/gridfm_graphkit → gridfm_graphkit-0.0.4/gridfm_graphkit/io}/__init__.py +0 -0
- {gridfm_graphkit-0.0.3/gridfm_graphkit/datasets → gridfm_graphkit-0.0.4/gridfm_graphkit/tasks}/__init__.py +0 -0
- {gridfm_graphkit-0.0.3/gridfm_graphkit/evaluation → gridfm_graphkit-0.0.4/gridfm_graphkit/training}/__init__.py +0 -0
- {gridfm_graphkit-0.0.3/gridfm_graphkit/io → gridfm_graphkit-0.0.4/gridfm_graphkit/utils}/__init__.py +0 -0
- {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.4}/gridfm_graphkit.egg-info/dependency_links.txt +0 -0
- {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.4}/gridfm_graphkit.egg-info/entry_points.txt +0 -0
- {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.4}/gridfm_graphkit.egg-info/top_level.txt +0 -0
- {gridfm_graphkit-0.0.3 → gridfm_graphkit-0.0.4}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: gridfm-graphkit
|
|
3
|
-
Version: 0.0.
|
|
3
|
+
Version: 0.0.4
|
|
4
4
|
Summary: Grid Foundation Model
|
|
5
5
|
Author-email: Matteo Mazzonelli <matteo.mazzonelli1@ibm.com>, Alban Puech <apuech@seas.harvard.edu>, Tamara Govindasamy <tamara.govindasamy@ibm.com>, Mangaliso Mngomezulu <mngomezulum@ibm.com>, Etienne Vos <etienne.vos@ibm.com>, Celia Cintas <celia.cintas@ibm.com>, Jonas Weiss <jwe@zurich.ibm.com>
|
|
6
6
|
Maintainer-email: Matteo Mazzonelli <matteo.mazzonelli1@ibm.com>
|
|
@@ -25,6 +25,7 @@ Requires-Dist: torch>=2.7.1
|
|
|
25
25
|
Requires-Dist: torch-geometric>=2.6.1
|
|
26
26
|
Requires-Dist: torchaudio>=2.7.1
|
|
27
27
|
Requires-Dist: torchvision>=0.22.1
|
|
28
|
+
Requires-Dist: lightning
|
|
28
29
|
Provides-Extra: dev
|
|
29
30
|
Requires-Dist: mkdocs-material; extra == "dev"
|
|
30
31
|
Requires-Dist: mkdocstrings[python]; extra == "dev"
|
|
@@ -38,6 +39,9 @@ Dynamic: license-file
|
|
|
38
39
|
|
|
39
40
|
# gridfm-graphkit
|
|
40
41
|
[](https://gridfm.github.io/gridfm-graphkit/)
|
|
42
|
+

|
|
43
|
+

|
|
44
|
+

|
|
41
45
|
|
|
42
46
|
This library is brought to you by the GridFM team to train, finetune and interact with a foundation model for the electric power grid.
|
|
43
47
|
|
|
@@ -73,7 +77,7 @@ pip install -e .[dev,test]
|
|
|
73
77
|
```
|
|
74
78
|
|
|
75
79
|
|
|
76
|
-
#
|
|
80
|
+
# CLI commands
|
|
77
81
|
|
|
78
82
|
An interface to train, fine-tune, and evaluate GridFM models using configurable YAML files and MLflow tracking.
|
|
79
83
|
|
|
@@ -83,9 +87,10 @@ gridfm_graphkit <command> [OPTIONS]
|
|
|
83
87
|
|
|
84
88
|
Available commands:
|
|
85
89
|
|
|
86
|
-
* `train` – Train a new model
|
|
87
|
-
* `
|
|
88
|
-
* `
|
|
90
|
+
* `train` – Train a new model from scratch
|
|
91
|
+
* `finetune` – Fine-tune an existing pre-trained model
|
|
92
|
+
* `evaluate` – Evaluate model performance on a dataset
|
|
93
|
+
* `predict` – Run inference and save predictions
|
|
89
94
|
|
|
90
95
|
---
|
|
91
96
|
|
|
@@ -99,75 +104,77 @@ gridfm_graphkit train --config path/to/config.yaml
|
|
|
99
104
|
|
|
100
105
|
| Argument | Type | Description | Default |
|
|
101
106
|
| ---------------- | ------ | ---------------------------------------------------------------- | ------- |
|
|
102
|
-
| `--config` | `str` | **Required
|
|
103
|
-
| `--
|
|
104
|
-
| `--
|
|
107
|
+
| `--config` | `str` | **Required**. Path to the training configuration YAML file. | `None` |
|
|
108
|
+
| `--exp_name` | `str` | **Optional**. MLflow experiment name. | `timestamp` |
|
|
109
|
+
| `--run_name` | `str` | **Optional**. MLflow run name. | `run` |
|
|
110
|
+
| `--log_dir`    | `str` | **Optional**. MLflow logging directory. | `mlruns` |
|
|
105
111
|
| `--data_path` | `str` | **Optional**. Root dataset directory. | `data` |
|
|
106
|
-
| `-c` | `flag` | **Optional**. Enable checkpoint mode. | `False` |
|
|
107
|
-
| `--model_exp_id` | `str` | **Required if `-c` is used**. MLflow experiment ID. | `None` |
|
|
108
|
-
| `--model_run_id` | `str` | **Required if `-c` is used**. MLflow run ID. | `None` |
|
|
109
112
|
|
|
110
113
|
### Examples
|
|
111
114
|
|
|
112
115
|
**Standard Training:**
|
|
113
116
|
|
|
114
117
|
```bash
|
|
115
|
-
gridfm_graphkit train --config config/
|
|
118
|
+
gridfm_graphkit train --config examples/config/case30_ieee_base.yaml --data_path examples/data
|
|
116
119
|
```
|
|
117
120
|
|
|
118
|
-
|
|
121
|
+
---
|
|
122
|
+
|
|
123
|
+
## Fine-Tuning Models
|
|
119
124
|
|
|
120
125
|
```bash
|
|
121
|
-
gridfm_graphkit
|
|
126
|
+
gridfm_graphkit finetune --config path/to/config.yaml --model_path path/to/model.pth
|
|
122
127
|
```
|
|
123
128
|
|
|
124
|
-
|
|
129
|
+
### Arguments
|
|
130
|
+
|
|
131
|
+
| Argument | Type | Description | Default |
|
|
132
|
+
| -------------- | ----- | ----------------------------------------------- | --------- |
|
|
133
|
+
| `--config` | `str` | **Required**. Fine-tuning configuration file. | `None` |
|
|
134
|
+
| `--model_path` | `str` | **Required**. Path to a pre-trained model file. | `None` |
|
|
135
|
+
| `--exp_name` | `str` | MLflow experiment name. | `timestamp` |
|
|
136
|
+
| `--run_name` | `str` | MLflow run name. | `run` |
|
|
137
|
+
| `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
|
|
138
|
+
| `--data_path` | `str` | Root dataset directory. | `data` |
|
|
125
139
|
|
|
126
|
-
```bash
|
|
127
|
-
gridfm_graphkit train -c --model_exp_id 123 --model_run_id abc
|
|
128
|
-
```
|
|
129
140
|
|
|
130
141
|
---
|
|
131
142
|
|
|
132
143
|
## Evaluating Models
|
|
133
144
|
|
|
134
145
|
```bash
|
|
135
|
-
gridfm_graphkit
|
|
146
|
+
gridfm_graphkit evaluate --config path/to/eval.yaml --model_path path/to/model.pth
|
|
136
147
|
```
|
|
137
148
|
|
|
138
149
|
### Arguments
|
|
139
150
|
|
|
140
|
-
| Argument
|
|
141
|
-
|
|
|
142
|
-
| `--
|
|
143
|
-
| `--
|
|
144
|
-
| `--
|
|
145
|
-
| `--
|
|
146
|
-
| `--
|
|
147
|
-
| `--
|
|
148
|
-
| `--data_path` | `str` | **Optional**. Path to dataset directory. | `data` |
|
|
149
|
-
|
|
150
|
-
### Examples
|
|
151
|
-
|
|
152
|
-
**Evaluate a Logged MLflow Model:**
|
|
153
|
-
|
|
154
|
-
```bash
|
|
155
|
-
gridfm_graphkit predict --config config/eval.yaml --eval_name run_eval --model_exp_id 1 --model_run_id abc
|
|
156
|
-
```
|
|
151
|
+
| Argument | Type | Description | Default |
|
|
152
|
+
| -------------- | ----- | ---------------------------------------- | --------- |
|
|
153
|
+
| `--config` | `str` | **Required**. Path to evaluation config. | `None` |
|
|
154
|
+
| `--model_path` | `str` | Path to the trained model file. | `None` |
|
|
155
|
+
| `--exp_name` | `str` | MLflow experiment name. | `timestamp` |
|
|
156
|
+
| `--run_name` | `str` | MLflow run name. | `run` |
|
|
157
|
+
| `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
|
|
158
|
+
| `--data_path` | `str` | Dataset directory. | `data` |
|
|
157
159
|
|
|
158
160
|
---
|
|
159
161
|
|
|
160
|
-
##
|
|
162
|
+
## Running Predictions
|
|
161
163
|
|
|
162
164
|
```bash
|
|
163
|
-
gridfm_graphkit
|
|
165
|
+
gridfm_graphkit predict --config path/to/config.yaml --model_path path/to/model.pth
|
|
164
166
|
```
|
|
165
167
|
|
|
166
168
|
### Arguments
|
|
167
169
|
|
|
168
|
-
| Argument
|
|
169
|
-
|
|
|
170
|
-
| `--config`
|
|
171
|
-
| `--model_path`
|
|
172
|
-
| `--
|
|
173
|
-
| `--
|
|
170
|
+
| Argument | Type | Description | Default |
|
|
171
|
+
| --------------- | ----- | --------------------------------------------- | --------- |
|
|
172
|
+
| `--config` | `str` | **Required**. Path to prediction config file. | `None` |
|
|
173
|
+
| `--model_path` | `str` | Path to the trained model file. | `None` |
|
|
174
|
+
| `--exp_name` | `str` | MLflow experiment name. | `timestamp` |
|
|
175
|
+
| `--run_name` | `str` | MLflow run name. | `run` |
|
|
176
|
+
| `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
|
|
177
|
+
| `--data_path` | `str` | Dataset directory. | `data` |
|
|
178
|
+
| `--output_path` | `str` | Directory where predictions are saved. | `data` |
|
|
179
|
+
|
|
180
|
+
---
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
# gridfm-graphkit
|
|
2
|
+
[](https://gridfm.github.io/gridfm-graphkit/)
|
|
3
|
+

|
|
4
|
+

|
|
5
|
+

|
|
6
|
+
|
|
7
|
+
This library is brought to you by the GridFM team to train, fine-tune, and interact with a foundation model for the electric power grid.
|
|
8
|
+
|
|
9
|
+
---
|
|
10
|
+
|
|
11
|
+
<p align="center">
|
|
12
|
+
<img src="https://raw.githubusercontent.com/gridfm/gridfm-graphkit/refs/heads/main/docs/figs/pre_training.png" alt="GridFM logo"/>
|
|
13
|
+
<br/>
|
|
14
|
+
</p>
|
|
15
|
+
|
|
16
|
+
# Installation
|
|
17
|
+
|
|
18
|
+
You can install `gridfm-graphkit` directly from PyPI:
|
|
19
|
+
|
|
20
|
+
```bash
|
|
21
|
+
pip install gridfm-graphkit
|
|
22
|
+
```
|
|
23
|
+
|
|
24
|
+
To contribute or develop locally, clone the repository and install in editable mode:
|
|
25
|
+
|
|
26
|
+
```bash
|
|
27
|
+
git clone git@github.com:gridfm/gridfm-graphkit.git
|
|
28
|
+
cd gridfm-graphkit
|
|
29
|
+
python -m venv venv
|
|
30
|
+
source venv/bin/activate
|
|
31
|
+
pip install -e .
|
|
32
|
+
```
|
|
33
|
+
|
|
34
|
+
For documentation generation and unit testing, install with the optional `dev` and `test` extras:
|
|
35
|
+
|
|
36
|
+
```bash
|
|
37
|
+
pip install -e .[dev,test]
|
|
38
|
+
```
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# CLI commands
|
|
42
|
+
|
|
43
|
+
An interface to train, fine-tune, evaluate, and run predictions with GridFM models using configurable YAML files and MLflow tracking.
|
|
44
|
+
|
|
45
|
+
```bash
|
|
46
|
+
gridfm_graphkit <command> [OPTIONS]
|
|
47
|
+
```
|
|
48
|
+
|
|
49
|
+
Available commands:
|
|
50
|
+
|
|
51
|
+
* `train` – Train a new model from scratch
|
|
52
|
+
* `finetune` – Fine-tune an existing pre-trained model
|
|
53
|
+
* `evaluate` – Evaluate model performance on a dataset
|
|
54
|
+
* `predict` – Run inference and save predictions
|
|
55
|
+
|
|
56
|
+
---
|
|
57
|
+
|
|
58
|
+
## Training Models
|
|
59
|
+
|
|
60
|
+
```bash
|
|
61
|
+
gridfm_graphkit train --config path/to/config.yaml
|
|
62
|
+
```
|
|
63
|
+
|
|
64
|
+
### Arguments
|
|
65
|
+
|
|
66
|
+
| Argument | Type | Description | Default |
|
|
67
|
+
| ---------------- | ------ | ---------------------------------------------------------------- | ------- |
|
|
68
|
+
| `--config` | `str` | **Required**. Path to the training configuration YAML file. | `None` |
|
|
69
|
+
| `--exp_name` | `str` | **Optional**. MLflow experiment name. | `timestamp` |
|
|
70
|
+
| `--run_name` | `str` | **Optional**. MLflow run name. | `run` |
|
|
71
|
+
| `--log_dir`    | `str` | **Optional**. MLflow logging directory. | `mlruns` |
|
|
72
|
+
| `--data_path` | `str` | **Optional**. Root dataset directory. | `data` |
|
|
73
|
+
|
|
74
|
+
### Examples
|
|
75
|
+
|
|
76
|
+
**Standard Training:**
|
|
77
|
+
|
|
78
|
+
```bash
|
|
79
|
+
gridfm_graphkit train --config examples/config/case30_ieee_base.yaml --data_path examples/data
|
|
80
|
+
```
|
|
81
|
+
|
|
82
|
+
---
|
|
83
|
+
|
|
84
|
+
## Fine-Tuning Models
|
|
85
|
+
|
|
86
|
+
```bash
|
|
87
|
+
gridfm_graphkit finetune --config path/to/config.yaml --model_path path/to/model.pth
|
|
88
|
+
```
|
|
89
|
+
|
|
90
|
+
### Arguments
|
|
91
|
+
|
|
92
|
+
| Argument | Type | Description | Default |
|
|
93
|
+
| -------------- | ----- | ----------------------------------------------- | --------- |
|
|
94
|
+
| `--config` | `str` | **Required**. Fine-tuning configuration file. | `None` |
|
|
95
|
+
| `--model_path` | `str` | **Required**. Path to a pre-trained model file. | `None` |
|
|
96
|
+
| `--exp_name` | `str` | MLflow experiment name. | `timestamp` |
|
|
97
|
+
| `--run_name` | `str` | MLflow run name. | `run` |
|
|
98
|
+
| `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
|
|
99
|
+
| `--data_path` | `str` | Root dataset directory. | `data` |
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
---
|
|
103
|
+
|
|
104
|
+
## Evaluating Models
|
|
105
|
+
|
|
106
|
+
```bash
|
|
107
|
+
gridfm_graphkit evaluate --config path/to/eval.yaml --model_path path/to/model.pth
|
|
108
|
+
```
|
|
109
|
+
|
|
110
|
+
### Arguments
|
|
111
|
+
|
|
112
|
+
| Argument | Type | Description | Default |
|
|
113
|
+
| -------------- | ----- | ---------------------------------------- | --------- |
|
|
114
|
+
| `--config` | `str` | **Required**. Path to evaluation config. | `None` |
|
|
115
|
+
| `--model_path` | `str` | Path to the trained model file. | `None` |
|
|
116
|
+
| `--exp_name` | `str` | MLflow experiment name. | `timestamp` |
|
|
117
|
+
| `--run_name` | `str` | MLflow run name. | `run` |
|
|
118
|
+
| `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
|
|
119
|
+
| `--data_path` | `str` | Dataset directory. | `data` |
|
|
120
|
+
|
|
121
|
+
---
|
|
122
|
+
|
|
123
|
+
## Running Predictions
|
|
124
|
+
|
|
125
|
+
```bash
|
|
126
|
+
gridfm_graphkit predict --config path/to/config.yaml --model_path path/to/model.pth
|
|
127
|
+
```
|
|
128
|
+
|
|
129
|
+
### Arguments
|
|
130
|
+
|
|
131
|
+
| Argument | Type | Description | Default |
|
|
132
|
+
| --------------- | ----- | --------------------------------------------- | --------- |
|
|
133
|
+
| `--config` | `str` | **Required**. Path to prediction config file. | `None` |
|
|
134
|
+
| `--model_path` | `str` | Path to the trained model file. | `None` |
|
|
135
|
+
| `--exp_name` | `str` | MLflow experiment name. | `timestamp` |
|
|
136
|
+
| `--run_name` | `str` | MLflow run name. | `run` |
|
|
137
|
+
| `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
|
|
138
|
+
| `--data_path` | `str` | Dataset directory. | `data` |
|
|
139
|
+
| `--output_path` | `str` | Directory where predictions are saved. | `data` |
|
|
140
|
+
|
|
141
|
+
---
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
from datetime import datetime
|
|
3
|
+
from gridfm_graphkit.cli import main_cli
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def main():
    """Entry point of the ``gridfm_graphkit`` command-line interface.

    Builds an argparse parser with four subcommands (``train``, ``finetune``,
    ``evaluate``, ``predict``), parses ``sys.argv`` and hands the resulting
    namespace to :func:`main_cli`.

    Every subcommand accepts ``--config`` (required), ``--exp_name``
    (defaults to a timestamp taken once per invocation), ``--run_name``,
    ``--log_dir`` and ``--data_path``. ``finetune`` additionally requires
    ``--model_path``; ``evaluate`` and ``predict`` accept an optional
    ``--model_path``; ``predict`` also accepts ``--output_path``.
    """
    parser = argparse.ArgumentParser(
        prog="gridfm_graphkit",
        description="gridfm-graphkit CLI",
    )
    subparsers = parser.add_subparsers(dest="command", required=True)
    # Computed once so every subcommand shares the same default experiment name.
    exp_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    def add_common_args(sub_parser):
        """Register the options shared by all subcommands."""
        sub_parser.add_argument("--config", type=str, required=True)
        sub_parser.add_argument("--exp_name", type=str, default=exp_name)
        sub_parser.add_argument("--run_name", type=str, default="run")
        sub_parser.add_argument("--log_dir", type=str, default="mlruns")
        sub_parser.add_argument("--data_path", type=str, default="data")

    # ---- TRAIN SUBCOMMAND ----
    train_parser = subparsers.add_parser("train", help="Run training")
    add_common_args(train_parser)

    # ---- FINETUNE SUBCOMMAND ----
    finetune_parser = subparsers.add_parser("finetune", help="Run fine-tuning")
    add_common_args(finetune_parser)
    finetune_parser.add_argument("--model_path", type=str, required=True)

    # ---- EVALUATE SUBCOMMAND ----
    evaluate_parser = subparsers.add_parser(
        "evaluate",
        help="Evaluate model performance",
    )
    add_common_args(evaluate_parser)
    evaluate_parser.add_argument("--model_path", type=str, default=None)

    # ---- PREDICT SUBCOMMAND ----
    predict_parser = subparsers.add_parser(
        "predict",
        help="Run inference and save predictions",
    )
    add_common_args(predict_parser)
    # Optional, like for ``evaluate`` (the previous ``required=None`` only
    # happened to behave like ``required=False``).
    predict_parser.add_argument("--model_path", type=str, default=None)
    predict_parser.add_argument("--output_path", type=str, default="data")

    args = parser.parse_args()
    main_cli(args)


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
from gridfm_graphkit.datasets.powergrid_datamodule import LitGridDataModule
|
|
2
|
+
from gridfm_graphkit.io.param_handler import NestedNamespace
|
|
3
|
+
from gridfm_graphkit.training.callbacks import SaveBestModelStateDict
|
|
4
|
+
import numpy as np
|
|
5
|
+
import os
|
|
6
|
+
import yaml
|
|
7
|
+
import torch
|
|
8
|
+
import random
|
|
9
|
+
import pandas as pd
|
|
10
|
+
|
|
11
|
+
from gridfm_graphkit.tasks.feature_reconstruction_task import FeatureReconstructionTask
|
|
12
|
+
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
|
|
13
|
+
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
|
|
14
|
+
from lightning.pytorch.loggers import MLFlowLogger
|
|
15
|
+
import lightning as L
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def get_training_callbacks(args):
    """Create the Lightning callbacks attached to the Trainer.

    All three callbacks track ``"Validation loss"`` in ``"min"`` mode:

    * ``EarlyStopping`` with ``min_delta=args.callbacks.tol`` and
      ``patience=args.callbacks.patience``;
    * ``SaveBestModelStateDict`` writing ``best_model_state_dict.pt``;
    * ``ModelCheckpoint`` with ``save_top_k=0`` and ``save_last=True``
      (per Lightning's semantics: no top-k checkpoints, only the last one).

    Args:
        args (NestedNamespace): Parsed config; reads ``callbacks.tol`` and
            ``callbacks.patience``.

    Returns:
        list: ``[early_stopping, best_state_dict_saver, checkpoint]``.
    """
    metric = "Validation loss"
    return [
        EarlyStopping(
            monitor=metric,
            min_delta=args.callbacks.tol,
            patience=args.callbacks.patience,
            verbose=False,
            mode="min",
        ),
        SaveBestModelStateDict(
            monitor=metric,
            mode="min",
            filename="best_model_state_dict.pt",
        ),
        ModelCheckpoint(
            monitor=metric,
            mode="min",
            save_last=True,
            save_top_k=0,
        ),
    ]
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _save_predictions(predictions, output_path):
    """Flatten per-batch prediction dicts into ``<output_path>/predictions.csv``.

    Args:
        predictions: Value returned by ``trainer.predict``. Assumed to be a list
            of dicts holding the keys read below (``output``, ``mask_PQ``,
            ``mask_PV``, ``mask_REF``, ``scenario_id``, ``bus_number``).
        output_path (str): Directory to write into; created if missing.

    Returns:
        str | None: Path of the written CSV, or ``None`` if nothing was saved.
    """
    if not predictions:
        # An empty list (or None) would make np.concatenate raise
        # "need at least one array to concatenate".
        print("No predictions were returned; nothing saved.")
        return None

    def _stack(key):
        # Concatenate one field across all batches along the node axis.
        return np.concatenate([batch[key] for batch in predictions], axis=0)

    outputs = _stack("output")  # columns 0..5 are read below: [num_nodes, 6]

    df = pd.DataFrame(
        {
            "scenario": _stack("scenario_id"),
            "bus": _stack("bus_number"),
            "PD": outputs[:, 0],
            "QD": outputs[:, 1],
            "PG": outputs[:, 2],
            "QG": outputs[:, 3],
            "VM": outputs[:, 4],
            "VA": outputs[:, 5],
            "PQ": _stack("mask_PQ").astype(int),
            "PV": _stack("mask_PV").astype(int),
            "REF": _stack("mask_REF").astype(int),
        },
    )

    os.makedirs(output_path, exist_ok=True)
    csv_path = os.path.join(output_path, "predictions.csv")
    df.to_csv(csv_path, index=False)

    print(f"Saved predictions to {csv_path}")
    return csv_path


def main_cli(args):
    """Execute one gridfm_graphkit CLI command.

    Loads the YAML config at ``args.config``, seeds ``torch``/``random``/
    ``numpy`` with ``config.seed``, builds the data module and the
    ``FeatureReconstructionTask`` and then, depending on ``args.command``:

    * ``train``    - fit, then test.
    * ``finetune`` - load weights from ``args.model_path``, fit, then test.
    * ``evaluate`` - load weights from ``args.model_path``, then test.
    * ``predict``  - load weights from ``args.model_path``, predict, and write
      ``predictions.csv`` under ``args.output_path``.

    Args:
        args (argparse.Namespace): Parsed CLI arguments. Always reads
            ``command``, ``config``, ``log_dir``, ``exp_name``, ``run_name``
            and ``data_path``; ``model_path`` for non-train commands;
            ``output_path`` for ``predict``.

    Raises:
        ValueError: If a non-``train`` command is run without ``--model_path``.
    """
    model_path = getattr(args, "model_path", None)
    if args.command != "train" and model_path is None:
        # evaluate/predict declare --model_path optional, but the weights are
        # always loaded below; fail fast with a clear message instead of
        # torch.load(None) failing after config and data have been loaded.
        raise ValueError(
            f"'{args.command}' requires --model_path pointing to a saved state_dict."
        )

    logger = MLFlowLogger(
        save_dir=args.log_dir,
        experiment_name=args.exp_name,
        run_name=args.run_name,
    )

    with open(args.config, "r") as f:
        base_config = yaml.safe_load(f)

    config_args = NestedNamespace(**base_config)

    # Seed every RNG source used downstream for reproducibility.
    torch.manual_seed(config_args.seed)
    random.seed(config_args.seed)
    np.random.seed(config_args.seed)

    litGrid = LitGridDataModule(config_args, args.data_path)
    model = FeatureReconstructionTask(
        config_args,
        litGrid.node_normalizers,
        litGrid.edge_normalizers,
    )
    if args.command != "train":
        print(f"Loading model weights from {model_path}")
        # map_location="cpu" lets GPU-saved weights load on CPU-only hosts; the
        # Trainer moves the model to the configured accelerator afterwards.
        state_dict = torch.load(model_path, map_location="cpu")
        model.load_state_dict(state_dict)

    trainer = L.Trainer(
        logger=logger,
        accelerator=config_args.training.accelerator,
        devices=config_args.training.devices,
        strategy=config_args.training.strategy,
        log_every_n_steps=1,
        default_root_dir=args.log_dir,
        max_epochs=config_args.training.epochs,
        callbacks=get_training_callbacks(config_args),
    )
    if args.command in ("train", "finetune"):
        trainer.fit(model=model, datamodule=litGrid)

    if args.command != "predict":
        trainer.test(model=model, datamodule=litGrid)
    else:
        predictions = trainer.predict(model=model, datamodule=litGrid)
        _save_predictions(predictions, args.output_path)
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from gridfm_graphkit.datasets.transforms import (
|
|
2
|
+
AddPFMask,
|
|
3
|
+
AddIdentityMask,
|
|
4
|
+
AddRandomMask,
|
|
5
|
+
AddOPFMask,
|
|
6
|
+
)
|
|
7
|
+
from gridfm_graphkit.datasets.normalizers import (
|
|
8
|
+
Standardizer,
|
|
9
|
+
MinMaxNormalizer,
|
|
10
|
+
BaseMVANormalizer,
|
|
11
|
+
IdentityNormalizer,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
__all__ = [
|
|
15
|
+
"AddPFMask",
|
|
16
|
+
"AddIdentityMask",
|
|
17
|
+
"AddRandomMask",
|
|
18
|
+
"AddOPFMask",
|
|
19
|
+
"Standardizer",
|
|
20
|
+
"MinMaxNormalizer",
|
|
21
|
+
"BaseMVANormalizer",
|
|
22
|
+
"IdentityNormalizer",
|
|
23
|
+
]
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from gridfm_graphkit.datasets.globals import PD, QD, PG, QG, VA
|
|
2
|
+
from gridfm_graphkit.io.registries import NORMALIZERS_REGISTRY
|
|
2
3
|
import torch
|
|
3
4
|
from abc import ABC, abstractmethod
|
|
4
5
|
|
|
@@ -53,13 +54,25 @@ class Normalizer(ABC):
|
|
|
53
54
|
Original tensor.
|
|
54
55
|
"""
|
|
55
56
|
|
|
57
|
+
@abstractmethod
def get_stats(self) -> dict:
    """
    Report this normalizer's stored statistics as a plain dict, suitable for
    logging or inspection. Concrete subclasses must implement it.
    """
|
|
62
|
+
|
|
56
63
|
|
|
64
|
+
@NORMALIZERS_REGISTRY.register("minmax")
|
|
57
65
|
class MinMaxNormalizer(Normalizer):
|
|
58
66
|
"""
|
|
59
67
|
Scales each feature to the [0, 1] range.
|
|
68
|
+
|
|
69
|
+
Args:
|
|
70
|
+
node_data (bool): Whether data is node-level or edge-level
|
|
71
|
+
args (NestedNamespace): Parameters
|
|
72
|
+
|
|
60
73
|
"""
|
|
61
74
|
|
|
62
|
-
def __init__(self):
|
|
75
|
+
def __init__(self, node_data: bool, args):
|
|
63
76
|
self.min_val = None
|
|
64
77
|
self.max_val = None
|
|
65
78
|
|
|
@@ -95,13 +108,25 @@ class MinMaxNormalizer(Normalizer):
|
|
|
95
108
|
diff[diff == 0] = 1
|
|
96
109
|
return (normalized_data * diff) + self.min_val
|
|
97
110
|
|
|
111
|
+
def get_stats(self) -> dict:
    """Return the stored per-feature minima/maxima as lists (``None`` when unset)."""
    stats = {}
    for key, tensor in (("min_value", self.min_val), ("max_value", self.max_val)):
        stats[key] = None if tensor is None else tensor.tolist()
    return stats
|
|
98
116
|
|
|
117
|
+
|
|
118
|
+
@NORMALIZERS_REGISTRY.register("standard")
|
|
99
119
|
class Standardizer(Normalizer):
|
|
100
120
|
"""
|
|
101
121
|
Standardizes each feature to zero mean and unit variance.
|
|
122
|
+
|
|
123
|
+
Args:
|
|
124
|
+
node_data (bool): Whether data is node-level or edge-level
|
|
125
|
+
args (NestedNamespace): Parameters
|
|
126
|
+
|
|
102
127
|
"""
|
|
103
128
|
|
|
104
|
-
def __init__(self):
|
|
129
|
+
def __init__(self, node_data: bool, args):
|
|
105
130
|
self.mean = None
|
|
106
131
|
self.std = None
|
|
107
132
|
|
|
@@ -137,7 +162,14 @@ class Standardizer(Normalizer):
|
|
|
137
162
|
std[std == 0] = 1
|
|
138
163
|
return (normalized_data * std) + self.mean
|
|
139
164
|
|
|
165
|
+
def get_stats(self) -> dict:
    """Return the stored per-feature mean/std as lists (``None`` when unset)."""
    stats = {}
    for key, tensor in (("mean", self.mean), ("std", self.std)):
        stats[key] = None if tensor is None else tensor.tolist()
    return stats
|
|
170
|
+
|
|
140
171
|
|
|
172
|
+
@NORMALIZERS_REGISTRY.register("baseMVAnorm")
|
|
141
173
|
class BaseMVANormalizer(Normalizer):
|
|
142
174
|
"""
|
|
143
175
|
In power systems, a suitable normalization strategy must preserve the physical properties of
|
|
@@ -148,14 +180,17 @@ class BaseMVANormalizer(Normalizer):
|
|
|
148
180
|
preserving fundamental physical relationships.
|
|
149
181
|
"""
|
|
150
182
|
|
|
151
|
-
def __init__(self, node_data: bool,
|
|
183
|
+
def __init__(self, node_data: bool, args):
|
|
152
184
|
"""
|
|
153
185
|
Args:
|
|
154
|
-
node_data: Whether data is node-level or edge-level
|
|
155
|
-
|
|
186
|
+
node_data: Whether data is node-level or edge-level
|
|
187
|
+
args (NestedNamespace): Parameters
|
|
188
|
+
|
|
189
|
+
Attributes:
|
|
190
|
+
baseMVA (float): baseMVA found in casefile. From ``args.data.baseMVA``.
|
|
156
191
|
"""
|
|
157
192
|
self.node_data = node_data
|
|
158
|
-
self.baseMVA_orig =
|
|
193
|
+
self.baseMVA_orig = getattr(args.data, "baseMVA", 100)
|
|
159
194
|
self.baseMVA = None
|
|
160
195
|
|
|
161
196
|
def to(self, device):
|
|
@@ -208,12 +243,26 @@ class BaseMVANormalizer(Normalizer):
|
|
|
208
243
|
|
|
209
244
|
return normalized_data
|
|
210
245
|
|
|
246
|
+
def get_stats(self) -> dict:
    """Return ``baseMVA`` (initialised to ``None``) and the configured ``baseMVA_orig``."""
    return dict(baseMVA=self.baseMVA, baseMVA_orig=self.baseMVA_orig)
|
|
251
|
+
|
|
211
252
|
|
|
253
|
+
@NORMALIZERS_REGISTRY.register("identity")
|
|
212
254
|
class IdentityNormalizer(Normalizer):
|
|
213
255
|
"""
|
|
214
256
|
No normalization: returns data unchanged.
|
|
257
|
+
|
|
258
|
+
Args:
|
|
259
|
+
node_data: Whether data is node-level or edge-level
|
|
260
|
+
args (NestedNamespace): Parameters
|
|
215
261
|
"""
|
|
216
262
|
|
|
263
|
+
def __init__(self, node_data: bool, args):
    """Accept the common normalizer constructor arguments; nothing is stored."""
|
|
265
|
+
|
|
217
266
|
def fit(self, data: torch.Tensor) -> dict:
    """Nothing to fit for the identity mapping; always returns an empty dict."""
    return dict()
|
|
219
268
|
|
|
@@ -225,3 +274,6 @@ class IdentityNormalizer(Normalizer):
|
|
|
225
274
|
|
|
226
275
|
def inverse_transform(self, normalized_data: torch.Tensor) -> torch.Tensor:
    """Identity inverse: hand back the very same tensor object, unmodified."""
    return normalized_data
|
|
277
|
+
|
|
278
|
+
def get_stats(self) -> dict:
    """No statistics exist for the identity mapping; return an explanatory note."""
    note = "No normalization applied."
    return {"note": note}
|