gridfm-graphkit 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
@@ -0,0 +1,62 @@
1
+ import argparse
2
+ import torch
3
+ import mlflow
4
+ from gridfm_graphkit.cli import (
5
+ main_standard,
6
+ main_checkpoint,
7
+ main_eval,
8
+ main_fine_tuning,
9
+ )
10
+
11
+
12
def _build_parser():
    """Build the ``gridfm_graphkit`` argument parser.

    Returns:
        argparse.ArgumentParser: parser exposing the ``train``, ``finetune``
        and ``predict`` subcommands; the chosen one is stored in
        ``args.command``.
    """
    parser = argparse.ArgumentParser(
        prog="gridfm_graphkit",
        description="gridfm-graphkit CLI",
    )
    subparsers = parser.add_subparsers(dest="command", required=True)

    # ---- TRAIN SUBCOMMAND ----
    train_parser = subparsers.add_parser("train", help="Run training")
    train_parser.add_argument("--config", type=str, default=None)
    train_parser.add_argument("--grid", type=str, default=None)
    train_parser.add_argument("--exp", type=str, default=None)
    train_parser.add_argument("--data_path", type=str, default="data")
    train_parser.add_argument("-c", action="store_true", help="Start from checkpoint")
    train_parser.add_argument("--model_exp_id", type=str, default=None)
    train_parser.add_argument("--model_run_id", type=str, default=None)

    # ---- FINETUNE SUBCOMMAND ----
    finetune_parser = subparsers.add_parser("finetune", help="Run fine-tuning")
    finetune_parser.add_argument("--config", type=str, required=True)
    finetune_parser.add_argument("--model_path", type=str, required=True)
    finetune_parser.add_argument("--exp", type=str, default=None)
    finetune_parser.add_argument("--data_path", type=str, default="data")

    # ---- PREDICT SUBCOMMAND ----
    predict_parser = subparsers.add_parser("predict", help="Run prediction")
    predict_parser.add_argument("--model_path", type=str, default=None)
    predict_parser.add_argument("--config", type=str, required=True)
    predict_parser.add_argument("--eval_name", type=str, required=True)
    predict_parser.add_argument("--model_exp_id", type=str, default=None)
    predict_parser.add_argument("--model_run_id", type=str, default=None)
    predict_parser.add_argument("--model_name", type=str, default="best_model")
    predict_parser.add_argument("--data_path", type=str, default="data")

    return parser


def main():
    """Parse the command line and dispatch to the matching CLI entry point.

    ``train`` runs ``main_checkpoint`` when ``-c`` is given and
    ``main_standard`` otherwise; ``predict`` runs ``main_eval``; ``finetune``
    runs ``main_fine_tuning``. Each receives the parsed arguments and the
    selected torch device (CUDA when available, CPU otherwise).

    Exits with an argparse usage error (status 2) when ``train`` is invoked
    without the arguments its mode needs, before any MLflow state is touched.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # Fail fast on combinations that would otherwise only crash after an
    # MLflow run has been created.
    if args.command == "train":
        if args.c:
            # Resuming needs an existing run to read config/checkpoint from.
            if args.model_run_id is None:
                parser.error("-c requires --model_run_id")
        elif args.config is None:
            parser.error("--config is required unless resuming with -c")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    mlflow.set_tracking_uri("file:mlruns")

    if args.command == "train":
        if args.c:
            main_checkpoint(args, device)
        else:
            main_standard(args, device)
    elif args.command == "predict":
        main_eval(args, device)
    elif args.command == "finetune":
        main_fine_tuning(args, device)


if __name__ == "__main__":
    main()
gridfm_graphkit/cli.py ADDED
@@ -0,0 +1,530 @@
1
+ from gridfm_graphkit.datasets.powergrid import GridDatasetMem
2
+ from gridfm_graphkit.training.trainer import Trainer
3
+ from gridfm_graphkit.training.plugins import MLflowLoggerPlugin, CheckpointerPlugin
4
+ from gridfm_graphkit.training.callbacks import EarlyStopper
5
+ from gridfm_graphkit.datasets.utils import split_dataset
6
+ from gridfm_graphkit.evaluation.node_level import eval_node_level_task
7
+ from gridfm_graphkit.io.param_handler import (
8
+ NestedNamespace,
9
+ merge_dict,
10
+ load_normalizer,
11
+ load_model,
12
+ get_loss_function,
13
+ get_transform,
14
+ param_combination_gen,
15
+ )
16
+
17
+ import torch
18
+ from torch_geometric.loader import DataLoader
19
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
20
+ from torch.utils.data import ConcatDataset
21
+ from torch.utils.data import Subset
22
+ import numpy as np
23
+ import os
24
+ import mlflow
25
+ from datetime import datetime
26
+ import yaml
27
+ import random
28
+ import plotly.io as pio
29
+ import warnings
30
+
31
+
32
def run_training(
    config_path,
    grid_params,
    data_path,
    device,
    run,
    checkpoint_flag,
    model_path=None,
):
    """Train a model inside an MLflow run, then evaluate the best model.

    Artifacts (config, model files, data indices, test results) are written
    under ``mlruns/<experiment_id>/<run_id>/artifacts``.

    Args:
        config_path: Path to the YAML base config. Ignored when
            ``checkpoint_flag`` is true, in which case the config previously
            saved in the run's artifact directory is used instead.
        grid_params: Mapping merged into the base config via ``merge_dict``
            (pass ``{}`` for no overrides).
        data_path: Root data directory; each network listed in
            ``data.networks`` is loaded from ``<data_path>/<network>``.
        device: Torch device used for the model, normalizers and evaluation.
        run: Active MLflow run; only ``run.info.experiment_id`` and
            ``run.info.run_id`` are read.
        checkpoint_flag: When true, restore model, optimizer and scheduler
            state from ``checkpoint_last_epoch.pth`` in the run's model
            directory and continue from the epoch after the stored one.
        model_path: Optional path to a pickled model to start from instead of
            building one with ``load_model``.
    """
    # Define log directories
    artifact_dir = os.path.join(
        "mlruns",
        run.info.experiment_id,
        run.info.run_id,
        "artifacts",
    )
    config_dir = os.path.join(artifact_dir, "config")
    model_dir = os.path.join(artifact_dir, "model")
    data_dir = os.path.join(artifact_dir, "data_idx")
    test_dir = os.path.join(artifact_dir, "test")

    # Create log directories if they don't exist
    for directory in (artifact_dir, config_dir, model_dir, data_dir, test_dir):
        os.makedirs(directory, exist_ok=True)

    config_dest = os.path.join(config_dir, "config.yaml")

    # Load the base config; a resumed run reuses the config it saved earlier.
    if checkpoint_flag:
        config_path = config_dest
    with open(config_path, "r") as f:
        base_config = yaml.safe_load(f)

    # Deep merge the base config with grid parameters
    merge_dict(base_config, grid_params)

    # Save updated config file
    with open(config_dest, "w") as f:
        yaml.dump(base_config, f)

    args = NestedNamespace(**base_config)

    # Fix random seed
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    node_normalizers = []
    edge_normalizers = []
    datasets = []
    train_datasets = []
    val_datasets = []
    test_datasets = []

    for i, network in enumerate(args.data.networks):
        node_normalizer, edge_normalizer = load_normalizer(args=args)
        node_normalizers.append(node_normalizer)
        edge_normalizers.append(edge_normalizer)

        # Create torch dataset and split
        data_path_network = os.path.join(data_path, network)
        print(f"Loading {network} dataset")
        full_dataset = GridDatasetMem(
            root=data_path_network,
            norm_method=args.data.normalization,
            node_normalizer=node_normalizer,
            edge_normalizer=edge_normalizer,
            pe_dim=args.model.pe_dim,
            mask_dim=args.data.mask_dim,
            transform=get_transform(args=args),
        )
        # Keep the full dataset: evaluation and stats logging use it later.
        datasets.append(full_dataset)

        num_scenarios = args.data.scenarios[i]
        if num_scenarios > len(full_dataset):
            warnings.warn(
                f"Requested number of scenarios ({num_scenarios}) exceeds dataset size ({len(full_dataset)}). "
                "Using the full dataset instead.",
            )
            num_scenarios = len(full_dataset)

        # Restrict to the first ``num_scenarios`` samples
        subset = Subset(full_dataset, list(range(num_scenarios)))

        node_normalizer.to(device)
        edge_normalizer.to(device)

        train_dataset, val_dataset, test_dataset = split_dataset(
            subset,
            data_dir,
            args.data.val_ratio,
            args.data.test_ratio,
        )

        train_datasets.append(train_dataset)
        val_datasets.append(val_dataset)
        test_datasets.append(test_dataset)

    train_dataset_multi = ConcatDataset(train_datasets)
    val_dataset_multi = ConcatDataset(val_datasets)

    # Create DataLoaders; test sets stay per network so results can be
    # reported per network.
    train_loader = DataLoader(
        train_dataset_multi,
        batch_size=args.training.batch_size,
        shuffle=True,
    )
    val_loader = DataLoader(
        val_dataset_multi,
        batch_size=args.training.batch_size,
        shuffle=False,
    )
    test_loaders = [
        DataLoader(test_set, batch_size=args.training.batch_size, shuffle=False)
        for test_set in test_datasets
    ]

    # Create model
    # NOTE: weights_only=False unpickles arbitrary objects; only load model
    # and checkpoint files from trusted sources.
    if model_path:
        model = torch.load(model_path, weights_only=False, map_location=device)
    else:
        model = load_model(args=args)

    print(model)
    print(
        "Model parameters: ",
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    )

    checkpoint = None
    if checkpoint_flag:
        checkpoint_path = os.path.join(model_dir, "checkpoint_last_epoch.pth")
        # map_location lets a checkpoint saved on another device (e.g. GPU)
        # be resumed on the current one.
        checkpoint = torch.load(
            checkpoint_path,
            weights_only=False,
            map_location=device,
        )
        model.load_state_dict(checkpoint["model_state_dict"])

    model = model.to(device)

    # Optimizer and learning rate scheduler
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=args.optimizer.learning_rate,
        betas=(args.optimizer.beta1, args.optimizer.beta2),
    )
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=args.optimizer.lr_decay,
        patience=args.optimizer.lr_patience,
    )

    if checkpoint_flag:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

    best_model_path = os.path.join(model_dir, "best_model.pth")
    early_stopper = EarlyStopper(
        best_model_path,
        args.callbacks.patience,
        args.callbacks.tol,
    )

    loss_fn = get_loss_function(args)

    mlflow_plugin = MLflowLoggerPlugin(steps=10, params=args.flatten())
    checkpointer = CheckpointerPlugin(model_dir, steps=None)

    trainer = Trainer(
        model=model,
        optimizer=optimizer,
        device=device,
        loss_fn=loss_fn,
        early_stopper=early_stopper,
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        lr_scheduler=scheduler,
        plugins=[mlflow_plugin, checkpointer],
    )

    # Train model
    # NOTE(review): Trainer.train is presumably (start_epoch, num_epochs) —
    # confirm against the Trainer implementation.
    if checkpoint_flag:
        trainer.train(
            checkpoint["epoch"] + 1,
            args.training.epochs - checkpoint["epoch"] - 1,
        )
    else:
        trainer.train(0, args.training.epochs)

    # Save mask
    if args.data.learn_mask:
        mask_path = os.path.join(model_dir, "mask_value.txt")
        np.savetxt(mask_path, model.mask_value.numpy(force=True))

    # Load best_model onto the evaluation device
    best_model = torch.load(
        best_model_path,
        weights_only=False,
        map_location=device,
    )

    # Save best_mask
    if args.data.learn_mask:
        best_mask_path = os.path.join(model_dir, "best_mask_value.txt")
        np.savetxt(best_mask_path, best_model.mask_value.numpy(force=True))

    # Default to 0.5 if mask_ratio doesn't exist
    mask_ratio = getattr(args.data, "mask_ratio", 0.5)

    for i, network in enumerate(args.data.networks):
        for task in ["PF", "OPF", "Reconstruction"]:
            df, figs = eval_node_level_task(
                dataset=datasets[i],
                model=best_model,
                task=task,
                test_loader=test_loaders[i],
                mask_dim=args.data.mask_dim,
                mask_ratio=mask_ratio,
                node_normalizer=node_normalizers[i],
                device=device,
                plot_dist=args.verbose,
            )

            # Log metric results
            df_path = os.path.join(test_dir, f"{task}_metrics_results_{network}.csv")
            df.to_csv(df_path)

            plot_paths = os.path.join(
                test_dir,
                f"{task}_evaluation_plots_{network}.html",
            )
            # Overwrite (like the CSV above) so re-evaluating in a resumed
            # run directory does not accumulate duplicate plots.
            with open(plot_paths, "w") as f:
                for fig in figs:
                    f.write(pio.to_html(fig, full_html=False, include_plotlyjs="cdn"))

        # Log node and edge stats
        log_file_path = os.path.join(artifact_dir, f"stats_{network}.log")
        with open(log_file_path, "w") as log_file:
            log_file.write("Dataset node_stats: " + str(datasets[i].node_stats) + "\n")
            log_file.write("Dataset edge_stats: " + str(datasets[i].edge_stats) + "\n")

    # Record the command that re-evaluates this run's best model.
    eval_cmd_path = os.path.join(artifact_dir, "EVAL_CMD.txt")
    with open(eval_cmd_path, "w") as f:
        f.write(
            f"gridFM predict --model_exp_id {run.info.experiment_id} --model_run_id {run.info.run_id} --model_name best_model --config {config_dest} --eval_name YOUR_EVAL_NAME \n",
        )
281
+
282
+
283
def main_standard(args, device):
    """Run a fresh training, optionally once per grid-search combination.

    The MLflow experiment is named ``args.exp`` when given, otherwise
    ``exp_<timestamp>``. When ``args.grid`` is set, the YAML grid file is
    expanded with ``param_combination_gen`` and one MLflow run is started per
    combination; otherwise a single run is started with no overrides.

    Args:
        args: Parsed CLI arguments; reads ``exp``, ``grid``, ``config`` and
            ``data_path``.
        device: Torch device forwarded to ``run_training``.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    experiment_name = f"exp_{timestamp}" if args.exp is None else str(args.exp)

    mlflow.set_experiment(experiment_name)

    # Without a grid file there is exactly one run with empty overrides.
    if not args.grid:
        print("No grid search config file provided. Running single training")
        with mlflow.start_run() as run:
            run_training(
                args.config,
                {},
                args.data_path,
                device,
                run,
                checkpoint_flag=False,
            )
        return

    print(f"Grid search enabled. Using grid file: {args.grid}")

    # Parse grid parameters
    with open(args.grid, "r") as grid_file:
        grid_config = yaml.safe_load(grid_file)

    grid_combinations = param_combination_gen(grid_config)

    # Run experiments for all combinations
    for position, grid_params in enumerate(grid_combinations, start=1):
        print(
            f"\nGrid search: {position}/{len(grid_combinations)} with params: {grid_params}",
        )
        with mlflow.start_run() as run:
            run_training(
                args.config,
                grid_params,
                args.data_path,
                device,
                run,
                checkpoint_flag=False,
            )
326
+
327
+
328
def main_checkpoint(args, device):
    """Resume training of an existing MLflow run from its last checkpoint.

    The run identified by ``args.model_run_id`` is re-opened and
    ``run_training`` is called with ``checkpoint_flag=True``, so the config
    and checkpoint stored in that run's artifact directory are used.

    Args:
        args: Parsed CLI arguments; reads ``grid``, ``config``,
            ``model_exp_id``, ``model_run_id`` and ``data_path``.
        device: Torch device forwarded to ``run_training``.

    Raises:
        ValueError: If ``args.model_run_id`` is missing. Without it MLflow
            would start a brand-new run that has no saved config or
            checkpoint to resume from.
    """
    if args.model_run_id is None:
        raise ValueError(
            "model_run_id must be provided to resume training from a checkpoint",
        )

    # These options have no effect when resuming; tell the user rather than
    # silently ignoring them.
    if args.grid:
        warnings.warn("Grid search not supported with model checkpoint")
    if args.config:
        warnings.warn("No need to specify config file")

    with mlflow.start_run(
        experiment_id=args.model_exp_id,
        run_id=args.model_run_id,
    ) as run:
        run_training(args.config, {}, args.data_path, device, run, checkpoint_flag=True)
339
+
340
+
341
def main_fine_tuning(args, device):
    """Fine-tune a previously saved model in a new MLflow run.

    The MLflow experiment is named ``args.exp`` when given, otherwise
    ``exp_<timestamp>``. Training starts from the model stored at
    ``args.model_path`` rather than from a freshly built one.

    Args:
        args: Parsed CLI arguments; reads ``exp``, ``config``, ``data_path``
            and ``model_path``.
        device: Torch device forwarded to ``run_training``.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    experiment_name = f"exp_{timestamp}" if args.exp is None else str(args.exp)

    mlflow.set_experiment(experiment_name)
    with mlflow.start_run() as active_run:
        run_training(
            args.config,
            {},
            args.data_path,
            device,
            active_run,
            checkpoint_flag=False,
            model_path=args.model_path,
        )
359
+
360
+
361
def eval(model_path, run, config_path, data_path, device):
    """Evaluate a saved model on the test split of every configured network.

    Results are written under ``mlruns/<experiment_id>/<run_id>/artifacts``:
    a copy of the config, per-task metric CSVs and plot HTML files in
    ``test/``, and a ``stats_<network>.log`` file per network.

    Args:
        model_path: Path to the pickled model loaded with ``torch.load``.
        run: Active MLflow run; only ``run.info.experiment_id`` and
            ``run.info.run_id`` are read.
        config_path: Path to the YAML config describing data and model.
        data_path: Root data directory; each network is loaded from
            ``<data_path>/<network>``.
        device: Torch device used for the model, normalizers and evaluation.
    """
    loaded = torch.load(model_path, weights_only=False, map_location=device)
    model = loaded.to(device)

    artifact_dir = os.path.join(
        "mlruns",
        run.info.experiment_id,
        run.info.run_id,
        "artifacts",
    )

    # Define log directories
    config_dir = os.path.join(artifact_dir, "config")
    data_dir = os.path.join(artifact_dir, "data_idx")
    test_dir = os.path.join(artifact_dir, "test")

    # Create log directories if they don't exist
    for directory in (config_dir, data_dir, test_dir):
        os.makedirs(directory, exist_ok=True)

    # Load the base config
    with open(config_path, "r") as config_file:
        base_config = yaml.safe_load(config_file)

    # Save config file
    config_dest = os.path.join(config_dir, "config.yaml")
    with open(config_dest, "w") as config_file:
        yaml.dump(base_config, config_file)

    args = NestedNamespace(**base_config)

    # Fix random seed
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    node_normalizers = []
    edge_normalizers = []
    datasets = []
    test_datasets = []

    for idx, network in enumerate(args.data.networks):
        node_normalizer, edge_normalizer = load_normalizer(args=args)
        node_normalizers.append(node_normalizer)
        edge_normalizers.append(edge_normalizer)

        # Create torch dataset and split
        network_root = os.path.join(data_path, network)
        print(f"Loading {network} dataset")
        full_dataset = GridDatasetMem(
            root=network_root,
            norm_method=args.data.normalization,
            node_normalizer=node_normalizer,
            edge_normalizer=edge_normalizer,
            pe_dim=args.model.pe_dim,
            mask_dim=args.data.mask_dim,
            transform=get_transform(args=args),
        )
        # The full dataset is kept for evaluation and stats logging below.
        datasets.append(full_dataset)

        num_scenarios = args.data.scenarios[idx]
        if num_scenarios > len(full_dataset):
            warnings.warn(
                f"Requested number of scenarios ({num_scenarios}) exceeds dataset size ({len(full_dataset)}). "
                "Using the full dataset instead.",
            )
            num_scenarios = len(full_dataset)

        subset = Subset(full_dataset, list(range(num_scenarios)))

        node_normalizer.to(device)
        edge_normalizer.to(device)

        # Only the test portion of the split is needed here.
        _, _, test_dataset = split_dataset(
            subset,
            data_dir,
            args.data.val_ratio,
            args.data.test_ratio,
        )

        test_datasets.append(test_dataset)

    test_loaders = [
        DataLoader(test_set, batch_size=args.training.batch_size, shuffle=False)
        for test_set in test_datasets
    ]

    mlflow.log_params(args.flatten())

    # Default to 0.5 if mask_ratio doesn't exist
    mask_ratio = getattr(args.data, "mask_ratio", 0.5)

    for idx, network in enumerate(args.data.networks):
        for task in ["PF", "OPF", "Reconstruction"]:
            df, figs = eval_node_level_task(
                dataset=datasets[idx],
                model=model,
                task=task,
                test_loader=test_loaders[idx],
                mask_dim=args.data.mask_dim,
                mask_ratio=mask_ratio,
                node_normalizer=node_normalizers[idx],
                device=device,
                plot_dist=args.verbose,
            )

            # Log metric results
            metrics_file = os.path.join(
                test_dir,
                f"{task}_metrics_results_{network}.csv",
            )
            df.to_csv(metrics_file)

            plots_file = os.path.join(
                test_dir,
                f"{task}_evaluation_plots_{network}.html",
            )
            with open(plots_file, "a") as html_out:
                for fig in figs:
                    html_out.write(
                        pio.to_html(fig, full_html=False, include_plotlyjs="cdn"),
                    )

        # Log node and edge stats
        stats_file = os.path.join(artifact_dir, f"stats_{network}.log")
        node_line = "Dataset node_stats: " + str(datasets[idx].node_stats) + "\n"
        edge_line = "Dataset edge_stats: " + str(datasets[idx].edge_stats) + "\n"
        with open(stats_file, "w") as log_file:
            log_file.write(node_line)
            log_file.write(edge_line)
487
+
488
+
489
def main_eval(
    args,
    device,
):
    """Evaluate a model given either a file path or an MLflow run reference.

    With ``args.model_path`` set, a new run is started in the experiment
    named ``args.eval_name``. Otherwise the run identified by
    ``args.model_exp_id`` / ``args.model_run_id`` is re-opened as parent, a
    nested run named ``args.eval_name`` is started under it, and the model is
    read from the parent's ``artifacts/model/<model_name>.pth``.

    Args:
        args: Parsed CLI arguments; reads ``model_path``, ``model_exp_id``,
            ``model_run_id``, ``model_name``, ``eval_name``, ``config`` and
            ``data_path``.
        device: Torch device forwarded to ``eval``.

    Raises:
        ValueError: If neither a model path nor a complete
            (experiment id, run id, model name) triple is provided.
    """
    has_model_path = args.model_path is not None
    has_run_reference = all(
        value is not None
        for value in (args.model_exp_id, args.model_run_id, args.model_name)
    )
    if not has_model_path and not has_run_reference:
        raise ValueError(
            "Either model_path or (model_exp_id, model_run_id, model_name) must be provided",
        )

    if has_model_path:
        mlflow.set_experiment(args.eval_name)
        with mlflow.start_run() as run:
            eval(args.model_path, run, args.config, args.data_path, device)
        return

    # Start the parent run using the provided experiment ID and run ID
    # (necessary to create a child run), then open the nested run inside it.
    with mlflow.start_run(
        experiment_id=args.model_exp_id,
        run_id=args.model_run_id,
    ) as parent_run, mlflow.start_run(
        experiment_id=args.model_exp_id,
        parent_run_id=args.model_run_id,
        run_name=args.eval_name,
        nested=True,
    ) as nested_run:
        # load model from parent run artifact dir
        model_file = args.model_name + ".pth"
        model_path = os.path.join(
            "mlruns",
            parent_run.info.experiment_id,
            parent_run.info.run_id,
            "artifacts",
            "model",
            model_file,
        )

        eval(model_path, nested_run, args.config, args.data_path, device)
File without changes