gridfm-graphkit 0.0.6__tar.gz → 0.0.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gridfm_graphkit-0.0.7/PKG-INFO +343 -0
- gridfm_graphkit-0.0.7/README.md +300 -0
- gridfm_graphkit-0.0.7/gridfm_graphkit/__init__.py +8 -0
- gridfm_graphkit-0.0.7/gridfm_graphkit/__main__.py +400 -0
- gridfm_graphkit-0.0.7/gridfm_graphkit/cli.py +400 -0
- gridfm_graphkit-0.0.7/gridfm_graphkit/datasets/__init__.py +15 -0
- gridfm_graphkit-0.0.7/gridfm_graphkit/datasets/globals.py +54 -0
- gridfm_graphkit-0.0.7/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py +471 -0
- gridfm_graphkit-0.0.7/gridfm_graphkit/datasets/masking.py +325 -0
- gridfm_graphkit-0.0.7/gridfm_graphkit/datasets/normalizers.py +618 -0
- gridfm_graphkit-0.0.7/gridfm_graphkit/datasets/powergrid_hetero_dataset.py +258 -0
- gridfm_graphkit-0.0.7/gridfm_graphkit/datasets/task_transforms.py +60 -0
- gridfm_graphkit-0.0.7/gridfm_graphkit/datasets/transforms.py +126 -0
- gridfm_graphkit-0.0.7/gridfm_graphkit/datasets/utils.py +121 -0
- {gridfm_graphkit-0.0.6 → gridfm_graphkit-0.0.7}/gridfm_graphkit/io/param_handler.py +53 -27
- {gridfm_graphkit-0.0.6 → gridfm_graphkit-0.0.7}/gridfm_graphkit/io/registries.py +6 -1
- gridfm_graphkit-0.0.7/gridfm_graphkit/models/__init__.py +13 -0
- gridfm_graphkit-0.0.7/gridfm_graphkit/models/gnn_heterogeneous_gns.py +297 -0
- gridfm_graphkit-0.0.7/gridfm_graphkit/models/utils.py +192 -0
- gridfm_graphkit-0.0.7/gridfm_graphkit/tasks/__init__.py +5 -0
- gridfm_graphkit-0.0.7/gridfm_graphkit/tasks/base_task.py +133 -0
- gridfm_graphkit-0.0.7/gridfm_graphkit/tasks/opf_ac_dc_baseline.py +313 -0
- gridfm_graphkit-0.0.7/gridfm_graphkit/tasks/opf_task.py +612 -0
- gridfm_graphkit-0.0.7/gridfm_graphkit/tasks/pf_ac_dc_baseline.py +230 -0
- gridfm_graphkit-0.0.7/gridfm_graphkit/tasks/pf_task.py +431 -0
- gridfm_graphkit-0.0.7/gridfm_graphkit/tasks/reconstruction_tasks.py +104 -0
- gridfm_graphkit-0.0.7/gridfm_graphkit/tasks/se_task.py +187 -0
- gridfm_graphkit-0.0.7/gridfm_graphkit/tasks/utils.py +194 -0
- gridfm_graphkit-0.0.7/gridfm_graphkit/training/__init__.py +16 -0
- gridfm_graphkit-0.0.7/gridfm_graphkit/training/callbacks.py +94 -0
- gridfm_graphkit-0.0.7/gridfm_graphkit/training/loss.py +394 -0
- gridfm_graphkit-0.0.7/gridfm_graphkit.egg-info/PKG-INFO +343 -0
- {gridfm_graphkit-0.0.6 → gridfm_graphkit-0.0.7}/gridfm_graphkit.egg-info/SOURCES.txt +17 -11
- gridfm_graphkit-0.0.7/gridfm_graphkit.egg-info/requires.txt +27 -0
- {gridfm_graphkit-0.0.6 → gridfm_graphkit-0.0.7}/pyproject.toml +15 -10
- {gridfm_graphkit-0.0.6 → gridfm_graphkit-0.0.7}/tests/test_data_module.py +6 -8
- gridfm_graphkit-0.0.7/tests/test_edge_flows.py +119 -0
- gridfm_graphkit-0.0.7/tests/test_losses.py +77 -0
- gridfm_graphkit-0.0.7/tests/test_pipeline.py +106 -0
- gridfm_graphkit-0.0.7/tests/test_simulate_measurements.py +97 -0
- gridfm_graphkit-0.0.7/tests/test_yaml_configs.py +27 -0
- gridfm_graphkit-0.0.6/PKG-INFO +0 -184
- gridfm_graphkit-0.0.6/README.md +0 -146
- gridfm_graphkit-0.0.6/gridfm_graphkit/__init__.py +0 -6
- gridfm_graphkit-0.0.6/gridfm_graphkit/__main__.py +0 -58
- gridfm_graphkit-0.0.6/gridfm_graphkit/cli.py +0 -134
- gridfm_graphkit-0.0.6/gridfm_graphkit/datasets/__init__.py +0 -23
- gridfm_graphkit-0.0.6/gridfm_graphkit/datasets/globals.py +0 -19
- gridfm_graphkit-0.0.6/gridfm_graphkit/datasets/normalizers.py +0 -279
- gridfm_graphkit-0.0.6/gridfm_graphkit/datasets/powergrid_datamodule.py +0 -207
- gridfm_graphkit-0.0.6/gridfm_graphkit/datasets/powergrid_dataset.py +0 -227
- gridfm_graphkit-0.0.6/gridfm_graphkit/datasets/transforms.py +0 -237
- gridfm_graphkit-0.0.6/gridfm_graphkit/datasets/utils.py +0 -49
- gridfm_graphkit-0.0.6/gridfm_graphkit/models/__init__.py +0 -4
- gridfm_graphkit-0.0.6/gridfm_graphkit/models/gnn_transformer.py +0 -96
- gridfm_graphkit-0.0.6/gridfm_graphkit/models/gps_transformer.py +0 -140
- gridfm_graphkit-0.0.6/gridfm_graphkit/tasks/__init__.py +0 -0
- gridfm_graphkit-0.0.6/gridfm_graphkit/tasks/feature_reconstruction_task.py +0 -366
- gridfm_graphkit-0.0.6/gridfm_graphkit/training/__init__.py +0 -0
- gridfm_graphkit-0.0.6/gridfm_graphkit/training/callbacks.py +0 -49
- gridfm_graphkit-0.0.6/gridfm_graphkit/training/loss.py +0 -198
- gridfm_graphkit-0.0.6/gridfm_graphkit/utils/__init__.py +0 -0
- gridfm_graphkit-0.0.6/gridfm_graphkit/utils/utils.py +0 -42
- gridfm_graphkit-0.0.6/gridfm_graphkit/utils/visualization.py +0 -513
- gridfm_graphkit-0.0.6/gridfm_graphkit.egg-info/PKG-INFO +0 -184
- gridfm_graphkit-0.0.6/gridfm_graphkit.egg-info/requires.txt +0 -22
- gridfm_graphkit-0.0.6/tests/test_full_pipeline.py +0 -82
- gridfm_graphkit-0.0.6/tests/test_losses.py +0 -36
- gridfm_graphkit-0.0.6/tests/test_model_outputs.py +0 -69
- gridfm_graphkit-0.0.6/tests/test_normalization.py +0 -36
- gridfm_graphkit-0.0.6/tests/test_yaml_configs.py +0 -48
- {gridfm_graphkit-0.0.6 → gridfm_graphkit-0.0.7}/LICENSE +0 -0
- {gridfm_graphkit-0.0.6 → gridfm_graphkit-0.0.7}/gridfm_graphkit/datasets/postprocessing.py +0 -0
- {gridfm_graphkit-0.0.6 → gridfm_graphkit-0.0.7}/gridfm_graphkit/io/__init__.py +0 -0
- {gridfm_graphkit-0.0.6 → gridfm_graphkit-0.0.7}/gridfm_graphkit.egg-info/dependency_links.txt +0 -0
- {gridfm_graphkit-0.0.6 → gridfm_graphkit-0.0.7}/gridfm_graphkit.egg-info/entry_points.txt +0 -0
- {gridfm_graphkit-0.0.6 → gridfm_graphkit-0.0.7}/gridfm_graphkit.egg-info/top_level.txt +0 -0
- {gridfm_graphkit-0.0.6 → gridfm_graphkit-0.0.7}/setup.cfg +0 -0
|
@@ -0,0 +1,343 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: gridfm-graphkit
|
|
3
|
+
Version: 0.0.7
|
|
4
|
+
Summary: Grid Foundation Model
|
|
5
|
+
Author-email: Matteo Mazzonelli <matteo.mazzonelli1@ibm.com>, Alban Puech <apuech@seas.harvard.edu>, Tamara Govindasamy <tamara.govindasamy@ibm.com>, Mangaliso Mngomezulu <mngomezulum@ibm.com>, Etienne Vos <etienne.vos@ibm.com>, Celia Cintas <celia.cintas@ibm.com>, Jonas Weiss <jwe@zurich.ibm.com>
|
|
6
|
+
Maintainer-email: Matteo Mazzonelli <matteo.mazzonelli1@ibm.com>
|
|
7
|
+
License-Expression: Apache-2.0
|
|
8
|
+
Keywords: electric power grid,foundational model,graph neural networks
|
|
9
|
+
Classifier: Development Status :: 2 - Pre-Alpha
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
13
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
14
|
+
Requires-Python: <3.13,>=3.10
|
|
15
|
+
Description-Content-Type: text/markdown
|
|
16
|
+
License-File: LICENSE
|
|
17
|
+
Requires-Dist: mlflow>=3.1.0
|
|
18
|
+
Requires-Dist: nbformat>=5.10.4
|
|
19
|
+
Requires-Dist: networkx>=3.4.2
|
|
20
|
+
Requires-Dist: numpy>=2.2.6
|
|
21
|
+
Requires-Dist: pandas>=2.3.0
|
|
22
|
+
Requires-Dist: plotly>=6.1.2
|
|
23
|
+
Requires-Dist: pyyaml>=6.0.2
|
|
24
|
+
Requires-Dist: torch<2.9,>=2.7.1
|
|
25
|
+
Requires-Dist: torch-geometric>=2.6.1
|
|
26
|
+
Requires-Dist: torchaudio>=2.7.1
|
|
27
|
+
Requires-Dist: torchvision>=0.22.1
|
|
28
|
+
Requires-Dist: lightning
|
|
29
|
+
Requires-Dist: seaborn
|
|
30
|
+
Requires-Dist: urllib3>=2.6.0
|
|
31
|
+
Requires-Dist: gdown>=6.0.0
|
|
32
|
+
Requires-Dist: gridfm-datakit>=1.0.2
|
|
33
|
+
Provides-Extra: dev
|
|
34
|
+
Requires-Dist: mkdocs-material; extra == "dev"
|
|
35
|
+
Requires-Dist: mkdocstrings[python]; extra == "dev"
|
|
36
|
+
Requires-Dist: pre-commit; extra == "dev"
|
|
37
|
+
Requires-Dist: bandit; extra == "dev"
|
|
38
|
+
Requires-Dist: build; extra == "dev"
|
|
39
|
+
Provides-Extra: test
|
|
40
|
+
Requires-Dist: pytest; extra == "test"
|
|
41
|
+
Requires-Dist: pytest-cov; extra == "test"
|
|
42
|
+
Dynamic: license-file
|
|
43
|
+
|
|
44
|
+
<p align="center">
|
|
45
|
+
<img src="https://raw.githubusercontent.com/gridfm/gridfm-graphkit/refs/heads/main/docs/figs/KIT.png" alt="GridFM logo" style="width: 40%; height: auto;"/>
|
|
46
|
+
<br/>
|
|
47
|
+
</p>
|
|
48
|
+
|
|
49
|
+
<p align="center" style="font-size: 25px;">
|
|
50
|
+
<b>gridfm-graphkit</b>
|
|
51
|
+
</p>
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
[DOI: 10.5281/zenodo.17016737](https://doi.org/10.5281/zenodo.17016737)
|
|
55
|
+
[Documentation](https://gridfm.github.io/gridfm-graphkit/)
|
|
56
|
+

|
|
57
|
+
[OpenSSF Best Practices](https://www.bestpractices.dev/projects/12802)
|
|
58
|
+
[OpenSSF Scorecard](https://scorecard.dev/viewer/?uri=github.com/gridfm/gridfm-graphkit)
|
|
59
|
+

|
|
60
|
+

|
|
61
|
+
|
|
62
|
+
This library is brought to you by the GridFM team to train, fine-tune, and interact with a foundation model for the electric power grid.
|
|
63
|
+
|
|
64
|
+
---
|
|
65
|
+
|
|
66
|
+
# Installation
|
|
67
|
+
|
|
68
|
+
Create and activate a virtual environment using a supported Python version (3.10, 3.11, or 3.12; 3.12 is recommended):
|
|
69
|
+
```bash
|
|
70
|
+
python -m venv venv
|
|
71
|
+
source venv/bin/activate
|
|
72
|
+
```
|
|
73
|
+
|
|
74
|
+
Install gridfm-graphkit from PyPI
|
|
75
|
+
```bash
|
|
76
|
+
pip install gridfm-graphkit
|
|
77
|
+
```
|
|
78
|
+
|
|
79
|
+
**`torch-scatter` is a required dependency.** It cannot be bundled in `pyproject.toml` because the correct wheel depends on your PyTorch and CUDA versions, so it must be installed separately.
|
|
80
|
+
|
|
81
|
+
Get PyTorch + CUDA version for torch-scatter
|
|
82
|
+
```bash
|
|
83
|
+
TORCH_CUDA_VERSION=$(python -c "import torch; print(torch.__version__ + ('+cpu' if torch.version.cuda is None else ''))")
|
|
84
|
+
```
|
|
85
|
+
|
|
86
|
+
Install the correct torch-scatter wheel
|
|
87
|
+
```bash
|
|
88
|
+
pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH_CUDA_VERSION}.html
|
|
89
|
+
```
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
For documentation generation and unit testing, install with the optional `dev` and `test` extras:
|
|
93
|
+
|
|
94
|
+
```bash
|
|
95
|
+
pip install "gridfm-graphkit[dev,test]"
|
|
96
|
+
```
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
# CLI commands
|
|
100
|
+
|
|
101
|
+
Interface to train, fine-tune, evaluate, and run inference on GridFM models using YAML configs and MLflow tracking.
|
|
102
|
+
|
|
103
|
+
```bash
|
|
104
|
+
gridfm_graphkit <command> [OPTIONS]
|
|
105
|
+
```
|
|
106
|
+
|
|
107
|
+
Available commands:
|
|
108
|
+
|
|
109
|
+
* `train` - Train a new model from scratch
|
|
110
|
+
* `finetune` - Fine-tune an existing pre-trained model
|
|
111
|
+
* `evaluate` - Evaluate model performance on a dataset
|
|
112
|
+
* `predict` - Run inference and save predictions
* `benchmark` - Benchmark dataloader throughput
|
|
113
|
+
|
|
114
|
+
---
|
|
115
|
+
|
|
116
|
+
## Training Models
|
|
117
|
+
|
|
118
|
+
```bash
|
|
119
|
+
gridfm_graphkit train --config path/to/config.yaml
|
|
120
|
+
```
|
|
121
|
+
|
|
122
|
+
### Arguments
|
|
123
|
+
|
|
124
|
+
| Argument | Type | Description | Default |
|
|
125
|
+
| -------- | ---- | ----------- | ------- |
|
|
126
|
+
| `--config` | `str` | **Required**. Path to the training configuration YAML file. | `None` |
|
|
127
|
+
| `--exp_name` | `str` | MLflow experiment name. | `timestamp` |
|
|
128
|
+
| `--run_name` | `str` | MLflow run name. | `run` |
|
|
129
|
+
| `--log_dir` | `str` | MLflow tracking/logging directory. | `mlruns` |
|
|
130
|
+
| `--data_path` | `str` | Root dataset directory. | `data` |
|
|
131
|
+
| `--compile [MODE]` | `str` | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If flag is passed without a value, mode is `default`. | `None` |
|
|
132
|
+
| `--bfloat16` | `flag` | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | `False` |
|
|
133
|
+
| `--tf32` | `flag` | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | `False` |
|
|
134
|
+
| `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` |
|
|
135
|
+
| `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | `[]` |
|
|
136
|
+
| `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | `None` |
|
|
137
|
+
| `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | `None` |
|
|
138
|
+
| `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | `None` |
|
|
139
|
+
| `--compute_dc_ac_metrics` | `flag` | Compute ground-truth AC/DC power balance metrics on the test split. | `False` |
|
|
140
|
+
| `--mp_context` | `str` | DataLoader multiprocessing start method (`spawn`, `fork`, `forkserver`). Defaults to PyTorch's automatic choice. On Linux, `spawn` is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. | `None` |
|
|
141
|
+
|
|
142
|
+
### Examples
|
|
143
|
+
|
|
144
|
+
**Standard Training:**
|
|
145
|
+
|
|
146
|
+
```bash
|
|
147
|
+
gridfm_graphkit train --config examples/config/case30_ieee_base.yaml --data_path examples/data
|
|
148
|
+
```
|
|
149
|
+
|
|
150
|
+
---
|
|
151
|
+
|
|
152
|
+
## Fine-Tuning Models
|
|
153
|
+
|
|
154
|
+
```bash
|
|
155
|
+
gridfm_graphkit finetune --config path/to/config.yaml --model_path path/to/model.pt
|
|
156
|
+
```
|
|
157
|
+
|
|
158
|
+
### Arguments
|
|
159
|
+
|
|
160
|
+
| Argument | Type | Description | Default |
|
|
161
|
+
| -------- | ---- | ----------- | ------- |
|
|
162
|
+
| `--config` | `str` | **Required**. Fine-tuning configuration file. | `None` |
|
|
163
|
+
| `--model_path` | `str` | **Required**. Path to a pre-trained model state dict. | `None` |
|
|
164
|
+
| `--exp_name` | `str` | MLflow experiment name. | `timestamp` |
|
|
165
|
+
| `--run_name` | `str` | MLflow run name. | `run` |
|
|
166
|
+
| `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
|
|
167
|
+
| `--data_path` | `str` | Root dataset directory. | `data` |
|
|
168
|
+
| `--compile [MODE]` | `str` | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If flag is passed without a value, mode is `default`. | `None` |
|
|
169
|
+
| `--bfloat16` | `flag` | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | `False` |
|
|
170
|
+
| `--tf32` | `flag` | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | `False` |
|
|
171
|
+
| `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` |
|
|
172
|
+
| `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | `[]` |
|
|
173
|
+
| `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | `None` |
|
|
174
|
+
| `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | `None` |
|
|
175
|
+
| `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | `None` |
|
|
176
|
+
| `--compute_dc_ac_metrics` | `flag` | Compute ground-truth AC/DC power balance metrics on the test split. | `False` |
|
|
177
|
+
| `--mp_context` | `str` | DataLoader multiprocessing start method (`spawn`, `fork`, `forkserver`). Defaults to PyTorch's automatic choice. On Linux, `spawn` is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. | `None` |
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
---
|
|
181
|
+
|
|
182
|
+
## Evaluating Models
|
|
183
|
+
|
|
184
|
+
```bash
|
|
185
|
+
gridfm_graphkit evaluate --config path/to/eval.yaml --model_path path/to/model.pt
|
|
186
|
+
```
|
|
187
|
+
|
|
188
|
+
### Arguments
|
|
189
|
+
|
|
190
|
+
| Argument | Type | Description | Default |
|
|
191
|
+
| -------- | ---- | ----------- | ------- |
|
|
192
|
+
| `--config` | `str` | **Required**. Path to evaluation config. | `None` |
|
|
193
|
+
| `--model_path` | `str` | Path to the trained model state dict. | `None` |
|
|
194
|
+
| `--normalizer_stats` | `str` | Path to `normalizer_stats.pt` from a training run. Restores `fit_on_train` normalizers from saved statistics instead of re-fitting on the current split. | `None` |
|
|
195
|
+
| `--exp_name` | `str` | MLflow experiment name. | `timestamp` |
|
|
196
|
+
| `--run_name` | `str` | MLflow run name. | `run` |
|
|
197
|
+
| `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
|
|
198
|
+
| `--data_path` | `str` | Dataset directory. | `data` |
|
|
199
|
+
| `--compile [MODE]` | `str` | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If flag is passed without a value, mode is `default`. | `None` |
|
|
200
|
+
| `--bfloat16` | `flag` | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | `False` |
|
|
201
|
+
| `--tf32` | `flag` | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | `False` |
|
|
202
|
+
| `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` |
|
|
203
|
+
| `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | `[]` |
|
|
204
|
+
| `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | `None` |
|
|
205
|
+
| `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | `None` |
|
|
206
|
+
| `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | `None` |
|
|
207
|
+
| `--compute_dc_ac_metrics` | `flag` | Compute ground-truth AC/DC power balance metrics on the test split. | `False` |
|
|
208
|
+
| `--save_output` | `flag` | Save predictions as `<grid_name>_predictions.parquet` under MLflow artifacts (`.../artifacts/test`). | `False` |
|
|
209
|
+
| `--mp_context` | `str` | DataLoader multiprocessing start method (`spawn`, `fork`, `forkserver`). Defaults to PyTorch's automatic choice. On Linux, `spawn` is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. | `None` |
|
|
210
|
+
|
|
211
|
+
### Example with saved normalizer stats
|
|
212
|
+
|
|
213
|
+
When evaluating a model on a dataset, you can pass the normalizer statistics from the original training run to ensure the same normalization parameters are used:
|
|
214
|
+
|
|
215
|
+
```bash
|
|
216
|
+
gridfm_graphkit evaluate \
|
|
217
|
+
--config examples/config/HGNS_PF_datakit_case118.yaml \
|
|
218
|
+
--model_path mlruns/<experiment_id>/<run_id>/artifacts/model/best_model_state_dict.pt \
|
|
219
|
+
--normalizer_stats mlruns/<experiment_id>/<run_id>/artifacts/stats/normalizer_stats.pt \
|
|
220
|
+
--data_path data
|
|
221
|
+
```
|
|
222
|
+
|
|
223
|
+
> **Note:** The `--normalizer_stats` flag only affects normalizers with `fit_strategy = "fit_on_train"` (e.g. `HeteroDataMVANormalizer`). Per-sample normalizers (`HeteroDataPerSampleMVANormalizer`) always recompute their statistics from the current dataset regardless of this flag.
|
|
224
|
+
|
|
225
|
+
---
|
|
226
|
+
|
|
227
|
+
## Running Predictions
|
|
228
|
+
|
|
229
|
+
```bash
|
|
230
|
+
gridfm_graphkit predict --config path/to/config.yaml --model_path path/to/model.pt
|
|
231
|
+
```
|
|
232
|
+
|
|
233
|
+
### Arguments
|
|
234
|
+
|
|
235
|
+
| Argument | Type | Description | Default |
|
|
236
|
+
| -------- | ---- | ----------- | ------- |
|
|
237
|
+
| `--config` | `str` | **Required**. Path to prediction config file. | `None` |
|
|
238
|
+
| `--model_path` | `str` | Path to trained model state dict. Optional; may be defined in config. | `None` |
|
|
239
|
+
| `--normalizer_stats` | `str` | Path to `normalizer_stats.pt` from a training run. Restores `fit_on_train` normalizers from saved statistics. | `None` |
|
|
240
|
+
| `--exp_name` | `str` | MLflow experiment name. | `timestamp` |
|
|
241
|
+
| `--run_name` | `str` | MLflow run name. | `run` |
|
|
242
|
+
| `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
|
|
243
|
+
| `--data_path` | `str` | Dataset directory. | `data` |
|
|
244
|
+
| `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` |
|
|
245
|
+
| `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | `[]` |
|
|
246
|
+
| `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | `None` |
|
|
247
|
+
| `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | `None` |
|
|
248
|
+
| `--output_path` | `str` | Directory where predictions are saved as `<grid_name>_predictions.parquet`. | `data` |
|
|
249
|
+
| `--compile [MODE]` | `str` | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If flag is passed without a value, mode is `default`. | `None` |
|
|
250
|
+
| `--bfloat16` | `flag` | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | `False` |
|
|
251
|
+
| `--tf32` | `flag` | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | `False` |
|
|
252
|
+
| `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | `None` |
|
|
253
|
+
| `--mp_context` | `str` | DataLoader multiprocessing start method (`spawn`, `fork`, `forkserver`). Defaults to PyTorch's automatic choice. On Linux, `spawn` is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. | `None` |
|
|
254
|
+
|
|
255
|
+
---
|
|
256
|
+
|
|
257
|
+
## Benchmarking Dataloader Throughput
|
|
258
|
+
|
|
259
|
+
```bash
|
|
260
|
+
gridfm_graphkit benchmark --config path/to/config.yaml
|
|
261
|
+
```
|
|
262
|
+
|
|
263
|
+
### Arguments
|
|
264
|
+
|
|
265
|
+
| Argument | Type | Description | Default |
|
|
266
|
+
| -------- | ---- | ----------- | ------- |
|
|
267
|
+
| `--config` | `str` | **Required**. Path to configuration YAML file. | `None` |
|
|
268
|
+
| `--data_path` | `str` | Root dataset directory. | `data` |
|
|
269
|
+
| `--epochs` | `int` | Number of epochs to iterate through the train dataloader. | `3` |
|
|
270
|
+
| `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` |
|
|
271
|
+
| `--dataset_wrapper_cache_dir` | `str` | Directory for dataset wrapper disk cache. | `None` |
|
|
272
|
+
| `--num_workers` | `int` | Override `data.workers` from YAML. | `None` |
|
|
273
|
+
| `--plugins` | `list[str]` | Python packages to import for plugin registration. | `[]` |
|
|
274
|
+
| `--mp_context` | `str` | DataLoader multiprocessing start method (`spawn`, `fork`, `forkserver`). Defaults to PyTorch's automatic choice. On Linux, `spawn` is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. | `None` |
|
|
275
|
+
|
|
276
|
+
Use built-in help for full command details:
|
|
277
|
+
|
|
278
|
+
```bash
|
|
279
|
+
gridfm_graphkit --help
|
|
280
|
+
gridfm_graphkit <command> --help
|
|
281
|
+
```
|
|
282
|
+
|
|
283
|
+
---
|
|
284
|
+
|
|
285
|
+
## Running Tests
|
|
286
|
+
|
|
287
|
+
### Unit and Integration Tests
|
|
288
|
+
|
|
289
|
+
Install the test dependencies first (if not already done):
|
|
290
|
+
|
|
291
|
+
```bash
|
|
292
|
+
pip install -e ".[dev,test]"
|
|
293
|
+
```
|
|
294
|
+
|
|
295
|
+
Run the full unit test suite:
|
|
296
|
+
|
|
297
|
+
```bash
|
|
298
|
+
pytest ./tests
|
|
299
|
+
```
|
|
300
|
+
|
|
301
|
+
Run the base set integration tests:
|
|
302
|
+
|
|
303
|
+
```bash
|
|
304
|
+
pytest ./integrationtests/test_base_set.py
|
|
305
|
+
```
|
|
306
|
+
|
|
307
|
+
### Running Base Set Tests on an LSF Cluster (GPU)
|
|
308
|
+
|
|
309
|
+
To submit the base set integration tests as an interactive LSF job with GPU access, use `bsub`. Adjust the paths to match your environment:
|
|
310
|
+
|
|
311
|
+
```bash
|
|
312
|
+
bsub -gpu "num=1" \
|
|
313
|
+
-n 16 \
|
|
314
|
+
-R "rusage[mem=32GB] span[hosts=1]" \
|
|
315
|
+
-Is \
|
|
316
|
+
-J gridfm_base_set_tests \
|
|
317
|
+
/bin/bash -c "
|
|
318
|
+
cd /path/to/gridfm-graphkit && \
|
|
319
|
+
export PATH=/path/to/cuda/bin:\$PATH \
|
|
320
|
+
CUDA_HOME=/path/to/cuda \
|
|
321
|
+
LD_LIBRARY_PATH=/path/to/cuda/lib64:\$LD_LIBRARY_PATH && \
|
|
322
|
+
source /path/to/venv/bin/activate && \
|
|
323
|
+
pytest ./integrationtests/test_base_set.py
|
|
324
|
+
"
|
|
325
|
+
```
|
|
326
|
+
|
|
327
|
+
Key `bsub` options used above:
|
|
328
|
+
|
|
329
|
+
| Option | Description |
|
|
330
|
+
| ------ | ----------- |
|
|
331
|
+
| `-gpu "num=1"` | Request 1 GPU |
|
|
332
|
+
| `-n 16` | Request 16 CPU slots |
|
|
333
|
+
| `-R "rusage[mem=32GB] span[hosts=1]"` | Reserve 32 GB of memory on a single host |
|
|
334
|
+
| `-Is` | Run as an interactive shell session |
|
|
335
|
+
| `-J <job_name>` | Assign a name to the job |
|
|
336
|
+
|
|
337
|
+
**Concrete example** (adapt paths to your cluster setup):
|
|
338
|
+
|
|
339
|
+
```bash
|
|
340
|
+
bsub -gpu "num=1" -n 16 -R "rusage[mem=32GB] span[hosts=1]" -Is -J hpo_trial_190 /bin/bash -c "cd /dccstor/terratorch/users/rkie/gitco/gridfm-graphkit && export PATH=/opt/share/cuda-12.8.1/bin:\$PATH CUDA_HOME=/opt/share/cuda-12.8.1 LD_LIBRARY_PATH=/opt/share/cuda-12.8.1/lib64:\$LD_LIBRARY_PATH && source /u/rkie/venvs/venv_gridfm-graphkit/bin/activate && pytest ./integrationtests/test_base_set.py"
|
|
341
|
+
```
|
|
342
|
+
|
|
343
|
+
---
|
|
@@ -0,0 +1,300 @@
|
|
|
1
|
+
<p align="center">
|
|
2
|
+
<img src="https://raw.githubusercontent.com/gridfm/gridfm-graphkit/refs/heads/main/docs/figs/KIT.png" alt="GridFM logo" style="width: 40%; height: auto;"/>
|
|
3
|
+
<br/>
|
|
4
|
+
</p>
|
|
5
|
+
|
|
6
|
+
<p align="center" style="font-size: 25px;">
|
|
7
|
+
<b>gridfm-graphkit</b>
|
|
8
|
+
</p>
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
[DOI: 10.5281/zenodo.17016737](https://doi.org/10.5281/zenodo.17016737)
|
|
12
|
+
[Documentation](https://gridfm.github.io/gridfm-graphkit/)
|
|
13
|
+

|
|
14
|
+
[OpenSSF Best Practices](https://www.bestpractices.dev/projects/12802)
|
|
15
|
+
[OpenSSF Scorecard](https://scorecard.dev/viewer/?uri=github.com/gridfm/gridfm-graphkit)
|
|
16
|
+

|
|
17
|
+

|
|
18
|
+
|
|
19
|
+
This library is brought to you by the GridFM team to train, fine-tune, and interact with a foundation model for the electric power grid.
|
|
20
|
+
|
|
21
|
+
---
|
|
22
|
+
|
|
23
|
+
# Installation
|
|
24
|
+
|
|
25
|
+
Create and activate a virtual environment using a supported Python version (3.10, 3.11, or 3.12; 3.12 is recommended):
|
|
26
|
+
```bash
|
|
27
|
+
python -m venv venv
|
|
28
|
+
source venv/bin/activate
|
|
29
|
+
```
|
|
30
|
+
|
|
31
|
+
Install gridfm-graphkit from PyPI
|
|
32
|
+
```bash
|
|
33
|
+
pip install gridfm-graphkit
|
|
34
|
+
```
|
|
35
|
+
|
|
36
|
+
**`torch-scatter` is a required dependency.** It cannot be bundled in `pyproject.toml` because the correct wheel depends on your PyTorch and CUDA versions, so it must be installed separately.
|
|
37
|
+
|
|
38
|
+
Get PyTorch + CUDA version for torch-scatter
|
|
39
|
+
```bash
|
|
40
|
+
TORCH_CUDA_VERSION=$(python -c "import torch; print(torch.__version__ + ('+cpu' if torch.version.cuda is None else ''))")
|
|
41
|
+
```
|
|
42
|
+
|
|
43
|
+
Install the correct torch-scatter wheel
|
|
44
|
+
```bash
|
|
45
|
+
pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH_CUDA_VERSION}.html
|
|
46
|
+
```
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
For documentation generation and unit testing, install with the optional `dev` and `test` extras:
|
|
50
|
+
|
|
51
|
+
```bash
|
|
52
|
+
pip install "gridfm-graphkit[dev,test]"
|
|
53
|
+
```
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
# CLI commands
|
|
57
|
+
|
|
58
|
+
Interface to train, fine-tune, evaluate, and run inference on GridFM models using YAML configs and MLflow tracking.
|
|
59
|
+
|
|
60
|
+
```bash
|
|
61
|
+
gridfm_graphkit <command> [OPTIONS]
|
|
62
|
+
```
|
|
63
|
+
|
|
64
|
+
Available commands:
|
|
65
|
+
|
|
66
|
+
* `train` - Train a new model from scratch
|
|
67
|
+
* `finetune` - Fine-tune an existing pre-trained model
|
|
68
|
+
* `evaluate` - Evaluate model performance on a dataset
|
|
69
|
+
* `predict` - Run inference and save predictions
* `benchmark` - Benchmark dataloader throughput
|
|
70
|
+
|
|
71
|
+
---
|
|
72
|
+
|
|
73
|
+
## Training Models
|
|
74
|
+
|
|
75
|
+
```bash
|
|
76
|
+
gridfm_graphkit train --config path/to/config.yaml
|
|
77
|
+
```
|
|
78
|
+
|
|
79
|
+
### Arguments
|
|
80
|
+
|
|
81
|
+
| Argument | Type | Description | Default |
|
|
82
|
+
| -------- | ---- | ----------- | ------- |
|
|
83
|
+
| `--config` | `str` | **Required**. Path to the training configuration YAML file. | `None` |
|
|
84
|
+
| `--exp_name` | `str` | MLflow experiment name. | `timestamp` |
|
|
85
|
+
| `--run_name` | `str` | MLflow run name. | `run` |
|
|
86
|
+
| `--log_dir` | `str` | MLflow tracking/logging directory. | `mlruns` |
|
|
87
|
+
| `--data_path` | `str` | Root dataset directory. | `data` |
|
|
88
|
+
| `--compile [MODE]` | `str` | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If flag is passed without a value, mode is `default`. | `None` |
|
|
89
|
+
| `--bfloat16` | `flag` | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | `False` |
|
|
90
|
+
| `--tf32` | `flag` | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | `False` |
|
|
91
|
+
| `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` |
|
|
92
|
+
| `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | `[]` |
|
|
93
|
+
| `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | `None` |
|
|
94
|
+
| `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | `None` |
|
|
95
|
+
| `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | `None` |
|
|
96
|
+
| `--compute_dc_ac_metrics` | `flag` | Compute ground-truth AC/DC power balance metrics on the test split. | `False` |
|
|
97
|
+
| `--mp_context` | `str` | DataLoader multiprocessing start method (`spawn`, `fork`, `forkserver`). Defaults to PyTorch's automatic choice. On Linux, `spawn` is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. | `None` |
|
|
98
|
+
|
|
99
|
+
### Examples
|
|
100
|
+
|
|
101
|
+
**Standard Training:**
|
|
102
|
+
|
|
103
|
+
```bash
|
|
104
|
+
gridfm_graphkit train --config examples/config/case30_ieee_base.yaml --data_path examples/data
|
|
105
|
+
```
|
|
106
|
+
|
|
107
|
+
---
|
|
108
|
+
|
|
109
|
+
## Fine-Tuning Models
|
|
110
|
+
|
|
111
|
+
```bash
|
|
112
|
+
gridfm_graphkit finetune --config path/to/config.yaml --model_path path/to/model.pt
|
|
113
|
+
```
|
|
114
|
+
|
|
115
|
+
### Arguments
|
|
116
|
+
|
|
117
|
+
| Argument | Type | Description | Default |
|
|
118
|
+
| -------- | ---- | ----------- | ------- |
|
|
119
|
+
| `--config` | `str` | **Required**. Fine-tuning configuration file. | `None` |
|
|
120
|
+
| `--model_path` | `str` | **Required**. Path to a pre-trained model state dict. | `None` |
|
|
121
|
+
| `--exp_name` | `str` | MLflow experiment name. | `timestamp` |
|
|
122
|
+
| `--run_name` | `str` | MLflow run name. | `run` |
|
|
123
|
+
| `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
|
|
124
|
+
| `--data_path` | `str` | Root dataset directory. | `data` |
|
|
125
|
+
| `--compile [MODE]` | `str` | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If flag is passed without a value, mode is `default`. | `None` |
|
|
126
|
+
| `--bfloat16` | `flag` | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | `False` |
|
|
127
|
+
| `--tf32` | `flag` | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | `False` |
|
|
128
|
+
| `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` |
|
|
129
|
+
| `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | `[]` |
|
|
130
|
+
| `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | `None` |
|
|
131
|
+
| `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | `None` |
|
|
132
|
+
| `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | `None` |
|
|
133
|
+
| `--compute_dc_ac_metrics` | `flag` | Compute ground-truth AC/DC power balance metrics on the test split. | `False` |
|
|
134
|
+
| `--mp_context` | `str` | DataLoader multiprocessing start method (`spawn`, `fork`, `forkserver`). Defaults to PyTorch's automatic choice. On Linux, `spawn` is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. | `None` |
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
---
|
|
138
|
+
|
|
139
|
+
## Evaluating Models
|
|
140
|
+
|
|
141
|
+
```bash
|
|
142
|
+
gridfm_graphkit evaluate --config path/to/eval.yaml --model_path path/to/model.pt
|
|
143
|
+
```
|
|
144
|
+
|
|
145
|
+
### Arguments
|
|
146
|
+
|
|
147
|
+
| Argument | Type | Description | Default |
|
|
148
|
+
| -------- | ---- | ----------- | ------- |
|
|
149
|
+
| `--config` | `str` | **Required**. Path to evaluation config. | `None` |
|
|
150
|
+
| `--model_path` | `str` | Path to the trained model state dict. | `None` |
|
|
151
|
+
| `--normalizer_stats` | `str` | Path to `normalizer_stats.pt` from a training run. Restores `fit_on_train` normalizers from saved statistics instead of re-fitting on the current split. | `None` |
|
|
152
|
+
| `--exp_name` | `str` | MLflow experiment name. | `timestamp` |
|
|
153
|
+
| `--run_name` | `str` | MLflow run name. | `run` |
|
|
154
|
+
| `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
|
|
155
|
+
| `--data_path` | `str` | Dataset directory. | `data` |
|
|
156
|
+
| `--compile [MODE]` | `str` | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If flag is passed without a value, mode is `default`. | `None` |
|
|
157
|
+
| `--bfloat16` | `flag` | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | `False` |
|
|
158
|
+
| `--tf32` | `flag` | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | `False` |
|
|
159
|
+
| `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` |
|
|
160
|
+
| `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | `[]` |
|
|
161
|
+
| `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | `None` |
|
|
162
|
+
| `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | `None` |
|
|
163
|
+
| `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | `None` |
|
|
164
|
+
| `--compute_dc_ac_metrics` | `flag` | Compute ground-truth AC/DC power balance metrics on the test split. | `False` |
|
|
165
|
+
| `--save_output` | `flag` | Save predictions as `<grid_name>_predictions.parquet` under MLflow artifacts (`.../artifacts/test`). | `False` |
|
|
166
|
+
| `--mp_context` | `str` | DataLoader multiprocessing start method (`spawn`, `fork`, `forkserver`). Defaults to PyTorch's automatic choice. On Linux, `spawn` is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. | `None` |
|
|
167
|
+
|
|
168
|
+
### Example with saved normalizer stats
|
|
169
|
+
|
|
170
|
+
When evaluating a model on a dataset, you can pass the normalizer statistics from the original training run to ensure the same normalization parameters are used:
|
|
171
|
+
|
|
172
|
+
```bash
|
|
173
|
+
gridfm_graphkit evaluate \
|
|
174
|
+
--config examples/config/HGNS_PF_datakit_case118.yaml \
|
|
175
|
+
--model_path mlruns/<experiment_id>/<run_id>/artifacts/model/best_model_state_dict.pt \
|
|
176
|
+
--normalizer_stats mlruns/<experiment_id>/<run_id>/artifacts/stats/normalizer_stats.pt \
|
|
177
|
+
--data_path data
|
|
178
|
+
```
|
|
179
|
+
|
|
180
|
+
> **Note:** The `--normalizer_stats` flag only affects normalizers with `fit_strategy = "fit_on_train"` (e.g. `HeteroDataMVANormalizer`). Per-sample normalizers (`HeteroDataPerSampleMVANormalizer`) always recompute their statistics from the current dataset regardless of this flag.
|
|
181
|
+
|
|
182
|
+
---
|
|
183
|
+
|
|
184
|
+
## Running Predictions
|
|
185
|
+
|
|
186
|
+
```bash
|
|
187
|
+
gridfm_graphkit predict --config path/to/config.yaml --model_path path/to/model.pt
|
|
188
|
+
```
|
|
189
|
+
|
|
190
|
+
### Arguments
|
|
191
|
+
|
|
192
|
+
| Argument | Type | Description | Default |
|
|
193
|
+
| -------- | ---- | ----------- | ------- |
|
|
194
|
+
| `--config` | `str` | **Required**. Path to prediction config file. | `None` |
|
|
195
|
+
| `--model_path` | `str` | Path to trained model state dict. Optional; may be defined in config. | `None` |
|
|
196
|
+
| `--normalizer_stats` | `str` | Path to `normalizer_stats.pt` from a training run. Restores `fit_on_train` normalizers from saved statistics. | `None` |
|
|
197
|
+
| `--exp_name` | `str` | MLflow experiment name. | `timestamp` |
|
|
198
|
+
| `--run_name` | `str` | MLflow run name. | `run` |
|
|
199
|
+
| `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
|
|
200
|
+
| `--data_path` | `str` | Dataset directory. | `data` |
|
|
201
|
+
| `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` |
|
|
202
|
+
| `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | `[]` |
|
|
203
|
+
| `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | `None` |
|
|
204
|
+
| `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | `None` |
|
|
205
|
+
| `--output_path` | `str` | Directory where predictions are saved as `<grid_name>_predictions.parquet`. | `data` |
|
|
206
|
+
| `--compile [MODE]` | `str` | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If flag is passed without a value, mode is `default`. | `None` |
|
|
207
|
+
| `--bfloat16` | `flag` | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | `False` |
|
|
208
|
+
| `--tf32` | `flag` | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | `False` |
|
|
209
|
+
| `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | `None` |
|
|
210
|
+
| `--mp_context` | `str` | DataLoader multiprocessing start method (`spawn`, `fork`, `forkserver`). Defaults to PyTorch's automatic choice. On Linux, `spawn` is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. | `None` |
|
|
211
|
+
|
|
212
|
+
---
|
|
213
|
+
|
|
214
|
+
## Benchmarking Dataloader Throughput
|
|
215
|
+
|
|
216
|
+
```bash
|
|
217
|
+
gridfm_graphkit benchmark --config path/to/config.yaml
|
|
218
|
+
```
|
|
219
|
+
|
|
220
|
+
### Arguments
|
|
221
|
+
|
|
222
|
+
| Argument | Type | Description | Default |
|
|
223
|
+
| -------- | ---- | ----------- | ------- |
|
|
224
|
+
| `--config` | `str` | **Required**. Path to configuration YAML file. | `None` |
|
|
225
|
+
| `--data_path` | `str` | Root dataset directory. | `data` |
|
|
226
|
+
| `--epochs` | `int` | Number of epochs to iterate through the train dataloader. | `3` |
|
|
227
|
+
| `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` |
|
|
228
|
+
| `--dataset_wrapper_cache_dir` | `str` | Directory for dataset wrapper disk cache. | `None` |
|
|
229
|
+
| `--num_workers` | `int` | Override `data.workers` from YAML. | `None` |
|
|
230
|
+
| `--plugins` | `list[str]` | Python packages to import for plugin registration. | `[]` |
|
|
231
|
+
| `--mp_context` | `str` | DataLoader multiprocessing start method (`spawn`, `fork`, `forkserver`). Defaults to PyTorch's automatic choice. On Linux, `spawn` is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. | `None` |
|
|
232
|
+
|
|
233
|
+
Use the built-in help for full command details:
|
|
234
|
+
|
|
235
|
+
```bash
|
|
236
|
+
gridfm_graphkit --help
|
|
237
|
+
gridfm_graphkit <command> --help
|
|
238
|
+
```
|
|
239
|
+
|
|
240
|
+
---
|
|
241
|
+
|
|
242
|
+
## Running Tests
|
|
243
|
+
|
|
244
|
+
### Unit and Integration Tests
|
|
245
|
+
|
|
246
|
+
Install the test dependencies first (if not already done):
|
|
247
|
+
|
|
248
|
+
```bash
|
|
249
|
+
pip install -e ".[dev,test]"
|
|
250
|
+
```
|
|
251
|
+
|
|
252
|
+
Run the full unit test suite:
|
|
253
|
+
|
|
254
|
+
```bash
|
|
255
|
+
pytest ./tests
|
|
256
|
+
```
|
|
257
|
+
|
|
258
|
+
Run the base set integration tests:
|
|
259
|
+
|
|
260
|
+
```bash
|
|
261
|
+
pytest ./integrationtests/test_base_set.py
|
|
262
|
+
```
|
|
263
|
+
|
|
264
|
+
### Running Base Set Tests on an LSF Cluster (GPU)
|
|
265
|
+
|
|
266
|
+
To submit the base set integration tests as an interactive LSF job with GPU access, use `bsub`. Adjust the paths to match your environment:
|
|
267
|
+
|
|
268
|
+
```bash
|
|
269
|
+
bsub -gpu "num=1" \
|
|
270
|
+
-n 16 \
|
|
271
|
+
-R "rusage[mem=32GB] span[hosts=1]" \
|
|
272
|
+
-Is \
|
|
273
|
+
-J gridfm_base_set_tests \
|
|
274
|
+
/bin/bash -c "
|
|
275
|
+
cd /path/to/gridfm-graphkit && \
|
|
276
|
+
export PATH=/path/to/cuda/bin:\$PATH \
|
|
277
|
+
CUDA_HOME=/path/to/cuda \
|
|
278
|
+
LD_LIBRARY_PATH=/path/to/cuda/lib64:\$LD_LIBRARY_PATH && \
|
|
279
|
+
source /path/to/venv/bin/activate && \
|
|
280
|
+
pytest ./integrationtests/test_base_set.py
|
|
281
|
+
"
|
|
282
|
+
```
|
|
283
|
+
|
|
284
|
+
Key `bsub` options used above:
|
|
285
|
+
|
|
286
|
+
| Option | Description |
|
|
287
|
+
| ------ | ----------- |
|
|
288
|
+
| `-gpu "num=1"` | Request 1 GPU |
|
|
289
|
+
| `-n 16` | Request 16 CPU slots |
|
|
290
|
+
| `-R "rusage[mem=32GB] span[hosts=1]"` | Reserve 32 GB of memory on a single host |
|
|
291
|
+
| `-Is` | Run as an interactive job with a pseudo-terminal |
|
|
292
|
+
| `-J <job_name>` | Assign a name to the job |
|
|
293
|
+
|
|
294
|
+
**Concrete example** (adapt paths to your cluster setup):
|
|
295
|
+
|
|
296
|
+
```bash
|
|
297
|
+
bsub -gpu "num=1" -n 16 -R "rusage[mem=32GB] span[hosts=1]" -Is -J gridfm_base_set_tests /bin/bash -c "cd /dccstor/terratorch/users/rkie/gitco/gridfm-graphkit && export PATH=/opt/share/cuda-12.8.1/bin:\$PATH CUDA_HOME=/opt/share/cuda-12.8.1 LD_LIBRARY_PATH=/opt/share/cuda-12.8.1/lib64:\$LD_LIBRARY_PATH && source /u/rkie/venvs/venv_gridfm-graphkit/bin/activate && pytest ./integrationtests/test_base_set.py"
|
|
298
|
+
```
|
|
299
|
+
|
|
300
|
+
---
|