gridfm-graphkit 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gridfm_graphkit/__init__.py +0 -0
- gridfm_graphkit/__main__.py +62 -0
- gridfm_graphkit/cli.py +530 -0
- gridfm_graphkit/datasets/__init__.py +0 -0
- gridfm_graphkit/datasets/data_normalization.py +227 -0
- gridfm_graphkit/datasets/globals.py +19 -0
- gridfm_graphkit/datasets/powergrid.py +192 -0
- gridfm_graphkit/datasets/transforms.py +223 -0
- gridfm_graphkit/datasets/utils.py +65 -0
- gridfm_graphkit/io/__init__.py +0 -0
- gridfm_graphkit/io/param_handler.py +293 -0
- gridfm_graphkit/models/__init__.py +0 -0
- gridfm_graphkit/models/gps_transformer.py +143 -0
- gridfm_graphkit/models/graphTransformer.py +96 -0
- gridfm_graphkit/training/__init__.py +0 -0
- gridfm_graphkit/training/callbacks.py +47 -0
- gridfm_graphkit/training/plugins.py +218 -0
- gridfm_graphkit/training/trainer.py +156 -0
- gridfm_graphkit/utils/__init__.py +0 -0
- gridfm_graphkit/utils/loss.py +198 -0
- gridfm_graphkit/utils/visualization.py +324 -0
- gridfm_graphkit-0.0.1.dist-info/METADATA +163 -0
- gridfm_graphkit-0.0.1.dist-info/RECORD +27 -0
- gridfm_graphkit-0.0.1.dist-info/WHEEL +5 -0
- gridfm_graphkit-0.0.1.dist-info/entry_points.txt +2 -0
- gridfm_graphkit-0.0.1.dist-info/licenses/LICENSE +201 -0
- gridfm_graphkit-0.0.1.dist-info/top_level.txt +1 -0
|
File without changes
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import torch
|
|
3
|
+
import mlflow
|
|
4
|
+
from gridfm_graphkit.cli import (
|
|
5
|
+
main_standard,
|
|
6
|
+
main_checkpoint,
|
|
7
|
+
main_eval,
|
|
8
|
+
main_fine_tuning,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def main():
    """Parse the command line and dispatch to the matching CLI entry point.

    Subcommands:
        train:    standard training (``main_standard``) or, with ``-c``,
                  resuming from a checkpoint (``main_checkpoint``).
        finetune: fine-tuning from a saved model (``main_fine_tuning``).
        predict:  evaluation of a saved model (``main_eval``).

    Invalid argument combinations for ``train`` are rejected through
    ``argparse`` (exit status 2) before any device or MLflow setup happens.
    """
    parser = argparse.ArgumentParser(
        prog="gridfm_graphkit",
        description="gridfm-graphkit CLI",
    )
    subparsers = parser.add_subparsers(dest="command", required=True)

    # ---- TRAIN SUBCOMMAND ----
    train_parser = subparsers.add_parser("train", help="Run training")
    train_parser.add_argument("--config", type=str, default=None)
    train_parser.add_argument("--grid", type=str, default=None)
    train_parser.add_argument("--exp", type=str, default=None)
    train_parser.add_argument("--data_path", type=str, default="data")
    train_parser.add_argument("-c", action="store_true", help="Start from checkpoint")
    train_parser.add_argument("--model_exp_id", type=str, default=None)
    train_parser.add_argument("--model_run_id", type=str, default=None)

    # ---- FINETUNE SUBCOMMAND ----
    finetune_parser = subparsers.add_parser("finetune", help="Run fine-tuning")
    finetune_parser.add_argument("--config", type=str, required=True)
    finetune_parser.add_argument("--model_path", type=str, required=True)
    finetune_parser.add_argument("--exp", type=str, default=None)
    finetune_parser.add_argument("--data_path", type=str, default="data")

    # ---- PREDICT SUBCOMMAND ----
    predict_parser = subparsers.add_parser("predict", help="Run prediction")
    predict_parser.add_argument("--model_path", type=str, default=None)
    predict_parser.add_argument("--config", type=str, required=True)
    predict_parser.add_argument("--eval_name", type=str, required=True)
    predict_parser.add_argument("--model_exp_id", type=str, default=None)
    predict_parser.add_argument("--model_run_id", type=str, default=None)
    predict_parser.add_argument("--model_name", type=str, default="best_model")
    predict_parser.add_argument("--data_path", type=str, default="data")

    args = parser.parse_args()

    # "--config" cannot simply be marked required on the train parser because
    # it is not needed when resuming (-c); validate the combination here so the
    # user gets a usage error instead of a late failure deep inside training.
    if args.command == "train":
        if args.c and args.model_run_id is None:
            train_parser.error("--model_run_id is required when resuming with -c")
        if not args.c and args.config is None:
            train_parser.error("--config is required unless resuming with -c")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Local file store; the cli module builds artifact paths under "mlruns".
    mlflow.set_tracking_uri("file:mlruns")

    if args.command == "train":
        if args.c:
            main_checkpoint(args, device)
        else:
            main_standard(args, device)
    elif args.command == "predict":
        main_eval(args, device)
    elif args.command == "finetune":
        main_fine_tuning(args, device)


if __name__ == "__main__":
    main()
|
gridfm_graphkit/cli.py
ADDED
|
@@ -0,0 +1,530 @@
|
|
|
1
|
+
from gridfm_graphkit.datasets.powergrid import GridDatasetMem
|
|
2
|
+
from gridfm_graphkit.training.trainer import Trainer
|
|
3
|
+
from gridfm_graphkit.training.plugins import MLflowLoggerPlugin, CheckpointerPlugin
|
|
4
|
+
from gridfm_graphkit.training.callbacks import EarlyStopper
|
|
5
|
+
from gridfm_graphkit.datasets.utils import split_dataset
|
|
6
|
+
from gridfm_graphkit.evaluation.node_level import eval_node_level_task
|
|
7
|
+
from gridfm_graphkit.io.param_handler import (
|
|
8
|
+
NestedNamespace,
|
|
9
|
+
merge_dict,
|
|
10
|
+
load_normalizer,
|
|
11
|
+
load_model,
|
|
12
|
+
get_loss_function,
|
|
13
|
+
get_transform,
|
|
14
|
+
param_combination_gen,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
from torch_geometric.loader import DataLoader
|
|
19
|
+
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
|
20
|
+
from torch.utils.data import ConcatDataset
|
|
21
|
+
from torch.utils.data import Subset
|
|
22
|
+
import numpy as np
|
|
23
|
+
import os
|
|
24
|
+
import mlflow
|
|
25
|
+
from datetime import datetime
|
|
26
|
+
import yaml
|
|
27
|
+
import random
|
|
28
|
+
import plotly.io as pio
|
|
29
|
+
import warnings
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def run_training(
    config_path,
    grid_params,
    data_path,
    device,
    run,
    checkpoint_flag,
    model_path=None,
):
    """Train a model inside an MLflow run, then evaluate the best model.

    All outputs are written below
    ``mlruns/<experiment_id>/<run_id>/artifacts`` in the ``config``,
    ``model``, ``data_idx`` and ``test`` sub-directories.

    Args:
        config_path: Path of the YAML base config. Ignored when
            ``checkpoint_flag`` is true; the ``config.yaml`` already stored in
            the run's ``config`` directory is read instead.
        grid_params: Dict of overrides merged into the base config with
            ``merge_dict`` (``{}`` for no overrides).
        data_path: Directory holding one sub-directory per entry of
            ``data.networks``.
        device: Torch device the model and normalizers are moved to.
        run: MLflow run; only ``run.info.experiment_id`` and
            ``run.info.run_id`` are read.
        checkpoint_flag: When true, restore model, optimizer and scheduler
            state from ``model/checkpoint_last_epoch.pth`` of this run and
            continue from the stored epoch.
        model_path: Optional path of a pickled model to start from instead of
            building a new one with ``load_model``.
    """
    # Define log directories
    artifact_dir = os.path.join(
        "mlruns",
        run.info.experiment_id,
        run.info.run_id,
        "artifacts",
    )
    config_dir = os.path.join(artifact_dir, "config")
    model_dir = os.path.join(artifact_dir, "model")
    data_dir = os.path.join(artifact_dir, "data_idx")
    test_dir = os.path.join(artifact_dir, "test")

    # Create log directories if they don't exist
    for directory in (artifact_dir, config_dir, model_dir, data_dir, test_dir):
        os.makedirs(directory, exist_ok=True)

    # Load the base config; a resumed run re-reads the config it saved earlier
    if checkpoint_flag:
        config_path = os.path.join(config_dir, "config.yaml")
    with open(config_path, "r") as f:
        base_config = yaml.safe_load(f)

    # Deep merge the base config with grid parameters
    merge_dict(base_config, grid_params)

    # Save updated config file
    config_dest = os.path.join(config_dir, "config.yaml")
    with open(config_dest, "w") as f:
        yaml.dump(base_config, f)

    args = NestedNamespace(**base_config)

    # Fix random seed
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    node_normalizers = []
    datasets = []
    train_datasets = []
    val_datasets = []
    test_datasets = []

    for i, network in enumerate(args.data.networks):
        # One normalizer pair per network
        node_normalizer, edge_normalizer = load_normalizer(args=args)
        node_normalizers.append(node_normalizer)

        # Create torch dataset and split
        data_path_network = os.path.join(data_path, network)
        print(f"Loading {network} dataset")
        dataset = GridDatasetMem(
            root=data_path_network,
            norm_method=args.data.normalization,
            node_normalizer=node_normalizer,
            edge_normalizer=edge_normalizer,
            pe_dim=args.model.pe_dim,
            mask_dim=args.data.mask_dim,
            transform=get_transform(args=args),
        )
        # Keep the full dataset: its node_stats/edge_stats are logged later
        datasets.append(dataset)

        num_scenarios = args.data.scenarios[i]
        if num_scenarios > len(dataset):
            warnings.warn(
                f"Requested number of scenarios ({num_scenarios}) exceeds dataset size ({len(dataset)}). "
                "Using the full dataset instead.",
            )
            num_scenarios = len(dataset)

        # Create a subset made of the first num_scenarios samples
        subset = Subset(dataset, list(range(num_scenarios)))

        node_normalizer.to(device)
        edge_normalizer.to(device)

        train_dataset, val_dataset, test_dataset = split_dataset(
            subset,
            data_dir,
            args.data.val_ratio,
            args.data.test_ratio,
        )

        train_datasets.append(train_dataset)
        val_datasets.append(val_dataset)
        test_datasets.append(test_dataset)

    # Training and validation mix all networks; testing stays per network
    train_dataset_multi = ConcatDataset(train_datasets)
    val_dataset_multi = ConcatDataset(val_datasets)

    # Create DataLoaders
    train_loader = DataLoader(
        train_dataset_multi,
        batch_size=args.training.batch_size,
        shuffle=True,
    )
    val_loader = DataLoader(
        val_dataset_multi,
        batch_size=args.training.batch_size,
        shuffle=False,
    )
    test_loaders = [
        DataLoader(test_set, batch_size=args.training.batch_size, shuffle=False)
        for test_set in test_datasets
    ]

    # Create model
    # NOTE: weights_only=False unpickles arbitrary objects; only load model
    # and checkpoint files that come from a trusted source.
    if model_path:
        model = torch.load(model_path, weights_only=False, map_location=device)
    else:
        model = load_model(args=args)

    print(model)
    print(
        "Model parameters: ",
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    )

    if checkpoint_flag:
        checkpoint_path = os.path.join(model_dir, "checkpoint_last_epoch.pth")
        # map_location lets a checkpoint written on one device (e.g. GPU) be
        # resumed on another (e.g. CPU-only machine).
        checkpoint = torch.load(
            checkpoint_path,
            weights_only=False,
            map_location=device,
        )
        model.load_state_dict(checkpoint["model_state_dict"])

    model = model.to(device)

    # Optimizer and learning rate scheduler
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=args.optimizer.learning_rate,
        betas=(args.optimizer.beta1, args.optimizer.beta2),
    )
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=args.optimizer.lr_decay,
        patience=args.optimizer.lr_patience,
    )

    if checkpoint_flag:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

    best_model_path = os.path.join(model_dir, "best_model.pth")
    early_stopper = EarlyStopper(
        best_model_path,
        args.callbacks.patience,
        args.callbacks.tol,
    )

    loss_fn = get_loss_function(args)

    mlflow_plugin = MLflowLoggerPlugin(steps=10, params=args.flatten())
    checkpointer = CheckpointerPlugin(model_dir, steps=None)

    trainer = Trainer(
        model=model,
        optimizer=optimizer,
        device=device,
        loss_fn=loss_fn,
        early_stopper=early_stopper,
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        lr_scheduler=scheduler,
        plugins=[mlflow_plugin, checkpointer],
    )

    # Train model
    # presumably Trainer.train(start_epoch, n_epochs) -- confirm against Trainer
    if checkpoint_flag:
        trainer.train(
            checkpoint["epoch"] + 1,
            args.training.epochs - checkpoint["epoch"] - 1,
        )
    else:
        trainer.train(0, args.training.epochs)

    # Save mask
    if args.data.learn_mask:
        mask_path = os.path.join(model_dir, "mask_value.txt")
        np.savetxt(mask_path, model.mask_value.numpy(force=True))

    # load best_model (file written at best_model_path handed to EarlyStopper)
    best_model = torch.load(
        best_model_path,
        weights_only=False,
        map_location=device,
    )

    # Save best_mask
    if args.data.learn_mask:
        best_mask_path = os.path.join(model_dir, "best_mask_value.txt")
        np.savetxt(best_mask_path, best_model.mask_value.numpy(force=True))

    # Default to 0.5 if mask_ratio doesn't exist; identical for every task
    mask_ratio = getattr(args.data, "mask_ratio", 0.5)

    for i, network in enumerate(args.data.networks):
        for task in ["PF", "OPF", "Reconstruction"]:
            df, figs = eval_node_level_task(
                dataset=datasets[i],
                model=best_model,
                task=task,
                test_loader=test_loaders[i],
                mask_dim=args.data.mask_dim,
                mask_ratio=mask_ratio,
                node_normalizer=node_normalizers[i],
                device=device,
                plot_dist=args.verbose,
            )

            # Log metric results
            df_path = os.path.join(test_dir, f"{task}_metrics_results_{network}.csv")
            df.to_csv(df_path)

            plot_paths = os.path.join(
                test_dir,
                f"{task}_evaluation_plots_{network}.html",
            )
            # All figures of one task/network are appended to a single HTML file
            with open(plot_paths, "a") as f:
                for fig in figs:
                    f.write(pio.to_html(fig, full_html=False, include_plotlyjs="cdn"))

        # Log node and edge stats
        log_file_path = os.path.join(artifact_dir, f"stats_{network}.log")
        with open(log_file_path, "w") as log_file:
            log_file.write("Dataset node_stats: " + str(datasets[i].node_stats) + "\n")
            log_file.write("Dataset edge_stats: " + str(datasets[i].edge_stats) + "\n")

    # Leave a ready-to-edit evaluation command next to the run artifacts.
    # NOTE(review): the same command is written twice, and it starts with
    # "gridFM" while the argparse prog is "gridfm_graphkit" -- confirm intent.
    eval_cmd = (
        f"gridFM predict --model_exp_id {run.info.experiment_id} "
        f"--model_run_id {run.info.run_id} --model_name best_model "
        f"--config {config_dest} --eval_name YOUR_EVAL_NAME"
    )
    eval_cmd_path = os.path.join(artifact_dir, "EVAL_CMD.txt")
    with open(eval_cmd_path, "w") as f:
        f.write(eval_cmd + " \n")
        f.write(eval_cmd)
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
def main_standard(args, device):
    """Run a fresh training, either once or once per grid-search combination.

    The MLflow experiment is named after ``args.exp`` when given, otherwise
    ``exp_<timestamp>``. When ``args.grid`` points to a YAML grid file, one
    MLflow run is started per parameter combination produced by
    ``param_combination_gen``; otherwise a single run with no overrides is
    started.

    Args:
        args: Parsed CLI namespace; reads ``exp``, ``grid``, ``config`` and
            ``data_path``.
        device: Torch device forwarded to ``run_training``.
    """
    if args.exp is None:
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        experiment_name = f"exp_{timestamp}"
    else:
        experiment_name = f"{args.exp}"

    mlflow.set_experiment(experiment_name)

    # Single training: handle it first and leave.
    if not args.grid:
        print("No grid search config file provided. Running single training")
        with mlflow.start_run() as single_run:
            run_training(
                args.config,
                {},
                args.data_path,
                device,
                single_run,
                checkpoint_flag=False,
            )
        return

    # Grid search: only reached when a grid parameter file is provided.
    print(f"Grid search enabled. Using grid file: {args.grid}")

    with open(args.grid, "r") as grid_file:
        grid_config = yaml.safe_load(grid_file)

    grid_combinations = param_combination_gen(grid_config)

    # One MLflow run per combination
    for position, grid_params in enumerate(grid_combinations, start=1):
        print(
            f"\nGrid search: {position}/{len(grid_combinations)} with params: {grid_params}",
        )
        with mlflow.start_run() as grid_run:
            run_training(
                args.config,
                grid_params,
                args.data_path,
                device,
                grid_run,
                checkpoint_flag=False,
            )
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def main_checkpoint(args, device):
    """Resume training of an existing MLflow run from its last checkpoint.

    The run identified by ``args.model_run_id`` is re-opened and
    ``run_training`` is called with ``checkpoint_flag=True``, so the config
    and checkpoint stored with that run are used.

    Args:
        args: Parsed CLI namespace; reads ``grid``, ``config``,
            ``model_exp_id``, ``model_run_id`` and ``data_path``.
        device: Torch device forwarded to ``run_training``.

    Raises:
        ValueError: If ``args.model_run_id`` is missing. Without it MLflow
            would start a brand-new run that has no checkpoint to resume.
    """
    if args.model_run_id is None:
        raise ValueError(
            "model_run_id must be provided to resume training from a checkpoint",
        )

    # These options are ignored when resuming; tell the user instead of failing
    if args.grid:
        warnings.warn("Grid search not supported with model checkpoint")
    if args.config:
        warnings.warn("No need to specify config file")

    with mlflow.start_run(
        experiment_id=args.model_exp_id,
        run_id=args.model_run_id,
    ) as run:
        run_training(args.config, {}, args.data_path, device, run, checkpoint_flag=True)
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
def main_fine_tuning(args, device):
    """Fine-tune a previously saved model in a new MLflow run.

    The experiment is named after ``args.exp`` when given, otherwise
    ``exp_<timestamp>``. Training starts from the model stored at
    ``args.model_path`` and uses the config at ``args.config`` without
    overrides.

    Args:
        args: Parsed CLI namespace; reads ``exp``, ``config``, ``data_path``
            and ``model_path``.
        device: Torch device forwarded to ``run_training``.
    """
    if args.exp is None:
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        experiment_name = f"exp_{timestamp}"
    else:
        experiment_name = f"{args.exp}"

    mlflow.set_experiment(experiment_name)

    no_overrides = {}
    with mlflow.start_run() as active_run:
        run_training(
            args.config,
            no_overrides,
            args.data_path,
            device,
            active_run,
            checkpoint_flag=False,
            model_path=args.model_path,
        )
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
def eval(model_path, run, config_path, data_path, device):
    """Evaluate a saved model on the test split of every configured network.

    Results are written below ``mlruns/<experiment_id>/<run_id>/artifacts``:
    the config used (``config/config.yaml``), per-task metric CSVs and HTML
    plots (``test/``), and one ``stats_<network>.log`` per network.

    NOTE: this function shadows the builtin ``eval`` inside this module; the
    name is kept because callers depend on it.

    Args:
        model_path: Path of the pickled model to load.
        run: MLflow run; only ``run.info.experiment_id`` and
            ``run.info.run_id`` are read.
        config_path: Path of the YAML config describing data and model.
        data_path: Directory holding one sub-directory per entry of
            ``data.networks``.
        device: Torch device the model and normalizers are moved to.
    """
    # NOTE: weights_only=False unpickles arbitrary objects; only load model
    # files that come from a trusted source.
    model = torch.load(model_path, weights_only=False, map_location=device).to(device)

    artifact_dir = os.path.join(
        "mlruns",
        run.info.experiment_id,
        run.info.run_id,
        "artifacts",
    )

    # Define log directories
    config_dir = os.path.join(artifact_dir, "config")
    data_dir = os.path.join(artifact_dir, "data_idx")
    test_dir = os.path.join(artifact_dir, "test")

    # Create log directories if they don't exist
    os.makedirs(config_dir, exist_ok=True)
    os.makedirs(data_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)

    # Load the base config
    with open(config_path, "r") as f:
        base_config = yaml.safe_load(f)

    # Save config file
    config_dest = os.path.join(config_dir, "config.yaml")
    with open(config_dest, "w") as f:
        yaml.dump(base_config, f)

    args = NestedNamespace(**base_config)

    # Fix random seed
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    node_normalizers = []
    datasets = []
    test_datasets = []

    for i, network in enumerate(args.data.networks):
        # One normalizer pair per network
        node_normalizer, edge_normalizer = load_normalizer(args=args)
        node_normalizers.append(node_normalizer)

        # Create torch dataset and split
        data_path_network = os.path.join(data_path, network)
        print(f"Loading {network} dataset")
        dataset = GridDatasetMem(
            root=data_path_network,
            norm_method=args.data.normalization,
            node_normalizer=node_normalizer,
            edge_normalizer=edge_normalizer,
            pe_dim=args.model.pe_dim,
            mask_dim=args.data.mask_dim,
            transform=get_transform(args=args),
        )
        # Keep the full dataset: its node_stats/edge_stats are logged later
        datasets.append(dataset)

        num_scenarios = args.data.scenarios[i]
        if num_scenarios > len(dataset):
            warnings.warn(
                f"Requested number of scenarios ({num_scenarios}) exceeds dataset size ({len(dataset)}). "
                "Using the full dataset instead.",
            )
            num_scenarios = len(dataset)

        # Restrict to the first num_scenarios samples
        subset = Subset(dataset, list(range(num_scenarios)))

        node_normalizer.to(device)
        edge_normalizer.to(device)

        # Only the test part of the split is needed for evaluation
        _, _, test_dataset = split_dataset(
            subset,
            data_dir,
            args.data.val_ratio,
            args.data.test_ratio,
        )

        test_datasets.append(test_dataset)

    test_loaders = [
        DataLoader(test_set, batch_size=args.training.batch_size, shuffle=False)
        for test_set in test_datasets
    ]

    mlflow.log_params(args.flatten())

    # Default to 0.5 if mask_ratio doesn't exist; identical for every task
    mask_ratio = getattr(args.data, "mask_ratio", 0.5)

    for i, network in enumerate(args.data.networks):
        for task in ["PF", "OPF", "Reconstruction"]:
            df, figs = eval_node_level_task(
                dataset=datasets[i],
                model=model,
                task=task,
                test_loader=test_loaders[i],
                mask_dim=args.data.mask_dim,
                mask_ratio=mask_ratio,
                node_normalizer=node_normalizers[i],
                device=device,
                plot_dist=args.verbose,
            )
            # Log metric results
            df_path = os.path.join(test_dir, f"{task}_metrics_results_{network}.csv")
            df.to_csv(df_path)

            plot_paths = os.path.join(
                test_dir,
                f"{task}_evaluation_plots_{network}.html",
            )
            # All figures of one task/network are appended to a single HTML file
            with open(plot_paths, "a") as f:
                for fig in figs:
                    f.write(pio.to_html(fig, full_html=False, include_plotlyjs="cdn"))

        # Log node and edge stats
        log_file_path = os.path.join(artifact_dir, f"stats_{network}.log")

        # Write the dataset statistics to the log file
        with open(log_file_path, "w") as log_file:
            log_file.write("Dataset node_stats: " + str(datasets[i].node_stats) + "\n")
            log_file.write("Dataset edge_stats: " + str(datasets[i].edge_stats) + "\n")
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
def main_eval(
    args,
    device,
):
    """Evaluate a model given either a file path or an MLflow run reference.

    With ``args.model_path`` the model file is evaluated in a new run of the
    experiment named ``args.eval_name``. Otherwise the model is read from the
    ``model`` artifact directory of the run ``args.model_run_id`` and the
    evaluation is recorded as a nested run named ``args.eval_name``.

    Args:
        args: Parsed CLI namespace; reads ``model_path``, ``model_exp_id``,
            ``model_run_id``, ``model_name``, ``eval_name``, ``config`` and
            ``data_path``.
        device: Torch device forwarded to ``eval``.

    Raises:
        ValueError: If neither ``model_path`` nor the complete triple
            (``model_exp_id``, ``model_run_id``, ``model_name``) is given.
    """
    run_reference = (args.model_exp_id, args.model_run_id, args.model_name)
    has_run_reference = all(part is not None for part in run_reference)

    if args.model_path is None and not has_run_reference:
        raise ValueError(
            "Either model_path or (model_exp_id, model_run_id, model_name) must be provided",
        )

    # Explicit model file: evaluate it in a fresh run and stop here.
    if args.model_path is not None:
        mlflow.set_experiment(args.eval_name)
        with mlflow.start_run() as run:
            eval(args.model_path, run, args.config, args.data_path, device)
        return

    # Re-open the parent run from the provided experiment ID and run ID (this
    # is necessary to create a child run), then start the nested run inside it.
    with mlflow.start_run(
        experiment_id=args.model_exp_id,
        run_id=args.model_run_id,
    ) as parent_run, mlflow.start_run(
        experiment_id=args.model_exp_id,
        parent_run_id=args.model_run_id,
        run_name=args.eval_name,
        nested=True,
    ) as nested_run:
        # The model lives in the parent run's artifact directory
        parent_model_dir = os.path.join(
            "mlruns",
            parent_run.info.experiment_id,
            parent_run.info.run_id,
            "artifacts",
            "model",
        )
        model_path = os.path.join(parent_model_dir, args.model_name + ".pth")

        eval(model_path, nested_run, args.config, args.data_path, device)
|
|
File without changes
|