gridfm-graphkit 0.0.5__tar.gz → 0.0.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77) hide show
  1. gridfm_graphkit-0.0.7/PKG-INFO +343 -0
  2. gridfm_graphkit-0.0.7/README.md +300 -0
  3. gridfm_graphkit-0.0.7/gridfm_graphkit/__init__.py +8 -0
  4. gridfm_graphkit-0.0.7/gridfm_graphkit/__main__.py +400 -0
  5. gridfm_graphkit-0.0.7/gridfm_graphkit/cli.py +400 -0
  6. gridfm_graphkit-0.0.7/gridfm_graphkit/datasets/__init__.py +15 -0
  7. gridfm_graphkit-0.0.7/gridfm_graphkit/datasets/globals.py +54 -0
  8. gridfm_graphkit-0.0.7/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py +471 -0
  9. gridfm_graphkit-0.0.7/gridfm_graphkit/datasets/masking.py +325 -0
  10. gridfm_graphkit-0.0.7/gridfm_graphkit/datasets/normalizers.py +618 -0
  11. gridfm_graphkit-0.0.7/gridfm_graphkit/datasets/powergrid_hetero_dataset.py +258 -0
  12. gridfm_graphkit-0.0.7/gridfm_graphkit/datasets/task_transforms.py +60 -0
  13. gridfm_graphkit-0.0.7/gridfm_graphkit/datasets/transforms.py +126 -0
  14. gridfm_graphkit-0.0.7/gridfm_graphkit/datasets/utils.py +121 -0
  15. {gridfm_graphkit-0.0.5 → gridfm_graphkit-0.0.7}/gridfm_graphkit/io/param_handler.py +53 -27
  16. {gridfm_graphkit-0.0.5 → gridfm_graphkit-0.0.7}/gridfm_graphkit/io/registries.py +6 -1
  17. gridfm_graphkit-0.0.7/gridfm_graphkit/models/__init__.py +13 -0
  18. gridfm_graphkit-0.0.7/gridfm_graphkit/models/gnn_heterogeneous_gns.py +297 -0
  19. gridfm_graphkit-0.0.7/gridfm_graphkit/models/utils.py +192 -0
  20. gridfm_graphkit-0.0.7/gridfm_graphkit/tasks/__init__.py +5 -0
  21. gridfm_graphkit-0.0.7/gridfm_graphkit/tasks/base_task.py +133 -0
  22. gridfm_graphkit-0.0.7/gridfm_graphkit/tasks/opf_ac_dc_baseline.py +313 -0
  23. gridfm_graphkit-0.0.7/gridfm_graphkit/tasks/opf_task.py +612 -0
  24. gridfm_graphkit-0.0.7/gridfm_graphkit/tasks/pf_ac_dc_baseline.py +230 -0
  25. gridfm_graphkit-0.0.7/gridfm_graphkit/tasks/pf_task.py +431 -0
  26. gridfm_graphkit-0.0.7/gridfm_graphkit/tasks/reconstruction_tasks.py +104 -0
  27. gridfm_graphkit-0.0.7/gridfm_graphkit/tasks/se_task.py +187 -0
  28. gridfm_graphkit-0.0.7/gridfm_graphkit/tasks/utils.py +194 -0
  29. gridfm_graphkit-0.0.7/gridfm_graphkit/training/__init__.py +16 -0
  30. gridfm_graphkit-0.0.7/gridfm_graphkit/training/callbacks.py +94 -0
  31. gridfm_graphkit-0.0.7/gridfm_graphkit/training/loss.py +394 -0
  32. gridfm_graphkit-0.0.7/gridfm_graphkit.egg-info/PKG-INFO +343 -0
  33. {gridfm_graphkit-0.0.5 → gridfm_graphkit-0.0.7}/gridfm_graphkit.egg-info/SOURCES.txt +17 -11
  34. {gridfm_graphkit-0.0.5 → gridfm_graphkit-0.0.7}/gridfm_graphkit.egg-info/requires.txt +6 -3
  35. {gridfm_graphkit-0.0.5 → gridfm_graphkit-0.0.7}/pyproject.toml +7 -4
  36. {gridfm_graphkit-0.0.5 → gridfm_graphkit-0.0.7}/tests/test_data_module.py +6 -8
  37. gridfm_graphkit-0.0.7/tests/test_edge_flows.py +119 -0
  38. gridfm_graphkit-0.0.7/tests/test_losses.py +77 -0
  39. gridfm_graphkit-0.0.7/tests/test_pipeline.py +106 -0
  40. gridfm_graphkit-0.0.7/tests/test_simulate_measurements.py +97 -0
  41. gridfm_graphkit-0.0.7/tests/test_yaml_configs.py +27 -0
  42. gridfm_graphkit-0.0.5/PKG-INFO +0 -181
  43. gridfm_graphkit-0.0.5/README.md +0 -141
  44. gridfm_graphkit-0.0.5/gridfm_graphkit/__init__.py +0 -6
  45. gridfm_graphkit-0.0.5/gridfm_graphkit/__main__.py +0 -58
  46. gridfm_graphkit-0.0.5/gridfm_graphkit/cli.py +0 -134
  47. gridfm_graphkit-0.0.5/gridfm_graphkit/datasets/__init__.py +0 -23
  48. gridfm_graphkit-0.0.5/gridfm_graphkit/datasets/globals.py +0 -19
  49. gridfm_graphkit-0.0.5/gridfm_graphkit/datasets/normalizers.py +0 -279
  50. gridfm_graphkit-0.0.5/gridfm_graphkit/datasets/powergrid_datamodule.py +0 -207
  51. gridfm_graphkit-0.0.5/gridfm_graphkit/datasets/powergrid_dataset.py +0 -227
  52. gridfm_graphkit-0.0.5/gridfm_graphkit/datasets/transforms.py +0 -237
  53. gridfm_graphkit-0.0.5/gridfm_graphkit/datasets/utils.py +0 -49
  54. gridfm_graphkit-0.0.5/gridfm_graphkit/models/__init__.py +0 -4
  55. gridfm_graphkit-0.0.5/gridfm_graphkit/models/gnn_transformer.py +0 -96
  56. gridfm_graphkit-0.0.5/gridfm_graphkit/models/gps_transformer.py +0 -140
  57. gridfm_graphkit-0.0.5/gridfm_graphkit/tasks/__init__.py +0 -0
  58. gridfm_graphkit-0.0.5/gridfm_graphkit/tasks/feature_reconstruction_task.py +0 -366
  59. gridfm_graphkit-0.0.5/gridfm_graphkit/training/__init__.py +0 -0
  60. gridfm_graphkit-0.0.5/gridfm_graphkit/training/callbacks.py +0 -49
  61. gridfm_graphkit-0.0.5/gridfm_graphkit/training/loss.py +0 -198
  62. gridfm_graphkit-0.0.5/gridfm_graphkit/utils/__init__.py +0 -0
  63. gridfm_graphkit-0.0.5/gridfm_graphkit/utils/utils.py +0 -42
  64. gridfm_graphkit-0.0.5/gridfm_graphkit/utils/visualization.py +0 -513
  65. gridfm_graphkit-0.0.5/gridfm_graphkit.egg-info/PKG-INFO +0 -181
  66. gridfm_graphkit-0.0.5/tests/test_full_pipeline.py +0 -82
  67. gridfm_graphkit-0.0.5/tests/test_losses.py +0 -36
  68. gridfm_graphkit-0.0.5/tests/test_model_outputs.py +0 -69
  69. gridfm_graphkit-0.0.5/tests/test_normalization.py +0 -36
  70. gridfm_graphkit-0.0.5/tests/test_yaml_configs.py +0 -48
  71. {gridfm_graphkit-0.0.5 → gridfm_graphkit-0.0.7}/LICENSE +0 -0
  72. {gridfm_graphkit-0.0.5 → gridfm_graphkit-0.0.7}/gridfm_graphkit/datasets/postprocessing.py +0 -0
  73. {gridfm_graphkit-0.0.5 → gridfm_graphkit-0.0.7}/gridfm_graphkit/io/__init__.py +0 -0
  74. {gridfm_graphkit-0.0.5 → gridfm_graphkit-0.0.7}/gridfm_graphkit.egg-info/dependency_links.txt +0 -0
  75. {gridfm_graphkit-0.0.5 → gridfm_graphkit-0.0.7}/gridfm_graphkit.egg-info/entry_points.txt +0 -0
  76. {gridfm_graphkit-0.0.5 → gridfm_graphkit-0.0.7}/gridfm_graphkit.egg-info/top_level.txt +0 -0
  77. {gridfm_graphkit-0.0.5 → gridfm_graphkit-0.0.7}/setup.cfg +0 -0
@@ -0,0 +1,343 @@
1
+ Metadata-Version: 2.4
2
+ Name: gridfm-graphkit
3
+ Version: 0.0.7
4
+ Summary: Grid Foundation Model
5
+ Author-email: Matteo Mazzonelli <matteo.mazzonelli1@ibm.com>, Alban Puech <apuech@seas.harvard.edu>, Tamara Govindasamy <tamara.govindasamy@ibm.com>, Mangaliso Mngomezulu <mngomezulum@ibm.com>, Etienne Vos <etienne.vos@ibm.com>, Celia Cintas <celia.cintas@ibm.com>, Jonas Weiss <jwe@zurich.ibm.com>
6
+ Maintainer-email: Matteo Mazzonelli <matteo.mazzonelli1@ibm.com>
7
+ License-Expression: Apache-2.0
8
+ Keywords: electric power grid,foundational model,graph neural networks
9
+ Classifier: Development Status :: 2 - Pre-Alpha
10
+ Classifier: Programming Language :: Python :: 3.10
11
+ Classifier: Programming Language :: Python :: 3.11
12
+ Classifier: Programming Language :: Python :: 3.12
13
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
14
+ Requires-Python: <3.13,>=3.10
15
+ Description-Content-Type: text/markdown
16
+ License-File: LICENSE
17
+ Requires-Dist: mlflow>=3.1.0
18
+ Requires-Dist: nbformat>=5.10.4
19
+ Requires-Dist: networkx>=3.4.2
20
+ Requires-Dist: numpy>=2.2.6
21
+ Requires-Dist: pandas>=2.3.0
22
+ Requires-Dist: plotly>=6.1.2
23
+ Requires-Dist: pyyaml>=6.0.2
24
+ Requires-Dist: torch<2.9,>=2.7.1
25
+ Requires-Dist: torch-geometric>=2.6.1
26
+ Requires-Dist: torchaudio>=2.7.1
27
+ Requires-Dist: torchvision>=0.22.1
28
+ Requires-Dist: lightning
29
+ Requires-Dist: seaborn
30
+ Requires-Dist: urllib3>=2.6.0
31
+ Requires-Dist: gdown>=6.0.0
32
+ Requires-Dist: gridfm-datakit>=1.0.2
33
+ Provides-Extra: dev
34
+ Requires-Dist: mkdocs-material; extra == "dev"
35
+ Requires-Dist: mkdocstrings[python]; extra == "dev"
36
+ Requires-Dist: pre-commit; extra == "dev"
37
+ Requires-Dist: bandit; extra == "dev"
38
+ Requires-Dist: build; extra == "dev"
39
+ Provides-Extra: test
40
+ Requires-Dist: pytest; extra == "test"
41
+ Requires-Dist: pytest-cov; extra == "test"
42
+ Dynamic: license-file
43
+
44
+ <p align="center">
45
+ <img src="https://raw.githubusercontent.com/gridfm/gridfm-graphkit/refs/heads/main/docs/figs/KIT.png" alt="GridFM logo" style="width: 40%; height: auto;"/>
46
+ <br/>
47
+ </p>
48
+
49
+ <p align="center" style="font-size: 25px;">
50
+ <b>gridfm-graphkit</b>
51
+ </p>
52
+
53
+
54
+ [![DOI](https://zenodo.org/badge/1007159095.svg)](https://doi.org/10.5281/zenodo.17016737)
55
+ [![Docs](https://img.shields.io/badge/docs-available-brightgreen)](https://gridfm.github.io/gridfm-graphkit/)
56
+ ![Coverage](https://img.shields.io/badge/coverage-83%25-yellowgreen)
57
+ [![OpenSSF Best Practices](https://www.bestpractices.dev/projects/12802/badge)](https://www.bestpractices.dev/projects/12802)
58
+ [![OpenSSF Scorecard](https://api.scorecard.dev/projects/github.com/gridfm/gridfm-graphkit/badge)](https://scorecard.dev/viewer/?uri=github.com/gridfm/gridfm-graphkit)
59
+ ![Python](https://img.shields.io/badge/python-3.10%20%E2%80%93%203.12-blue)
60
+ ![License](https://img.shields.io/badge/license-Apache%202.0-blue)
61
+
62
+ This library, developed by the GridFM team, lets you train, fine-tune, and interact with a foundation model for the electric power grid.
63
+
64
+ ---
65
+
66
+ # Installation
67
+
68
+ Create and activate a virtual environment (use Python 3.10, 3.11, or 3.12; 3.12 is recommended):
69
+ ```bash
70
+ python -m venv venv
71
+ source venv/bin/activate
72
+ ```
73
+
74
+ Install gridfm-graphkit from PyPI
75
+ ```bash
76
+ pip install gridfm-graphkit
77
+ ```
78
+
79
+ **`torch-scatter` is a required dependency.** It cannot be bundled in `pyproject.toml` because the correct wheel depends on your PyTorch and CUDA versions, so it must be installed separately.
80
+
81
+ Get PyTorch + CUDA version for torch-scatter
82
+ ```bash
83
+ TORCH_CUDA_VERSION=$(python -c "import torch; base = torch.__version__.split('+')[0]; print(base + ('+cpu' if torch.version.cuda is None else '+cu' + torch.version.cuda.replace('.', '')))")
84
+ ```
85
+
86
+ Install the correct torch-scatter wheel
87
+ ```bash
88
+ pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH_CUDA_VERSION}.html
89
+ ```
90
+
91
+
92
+ For documentation generation and unit testing, install with the optional `dev` and `test` extras:
93
+
94
+ ```bash
95
+ pip install "gridfm-graphkit[dev,test]"
96
+ ```
97
+
98
+
99
+ # CLI commands
100
+
101
+ Use the command-line interface to train, fine-tune, evaluate, and run inference on GridFM models with YAML configs and MLflow tracking.
102
+
103
+ ```bash
104
+ gridfm_graphkit <command> [OPTIONS]
105
+ ```
106
+
107
+ Available commands:
108
+
109
+ * `train` - Train a new model from scratch
110
+ * `finetune` - Fine-tune an existing pre-trained model
111
+ * `evaluate` - Evaluate model performance on a dataset
112
+ * `predict` - Run inference and save predictions
113
+
114
+ ---
115
+
116
+ ## Training Models
117
+
118
+ ```bash
119
+ gridfm_graphkit train --config path/to/config.yaml
120
+ ```
121
+
122
+ ### Arguments
123
+
124
+ | Argument | Type | Description | Default |
125
+ | -------- | ---- | ----------- | ------- |
126
+ | `--config` | `str` | **Required**. Path to the training configuration YAML file. | `None` |
127
+ | `--exp_name` | `str` | MLflow experiment name. | `timestamp` |
128
+ | `--run_name` | `str` | MLflow run name. | `run` |
129
+ | `--log_dir` | `str` | MLflow tracking/logging directory. | `mlruns` |
130
+ | `--data_path` | `str` | Root dataset directory. | `data` |
131
+ | `--compile [MODE]` | `str` | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If flag is passed without a value, mode is `default`. | `None` |
132
+ | `--bfloat16` | `flag` | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | `False` |
133
+ | `--tf32` | `flag` | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | `False` |
134
+ | `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` |
135
+ | `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | `[]` |
136
+ | `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | `None` |
137
+ | `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | `None` |
138
+ | `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | `None` |
139
+ | `--compute_dc_ac_metrics` | `flag` | Compute ground-truth AC/DC power balance metrics on the test split. | `False` |
140
+ | `--mp_context` | `str` | DataLoader multiprocessing start method (`spawn`, `fork`, `forkserver`). Defaults to PyTorch's automatic choice. On Linux, `spawn` is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. | `None` |
141
+
142
+ ### Examples
143
+
144
+ **Standard Training:**
145
+
146
+ ```bash
147
+ gridfm_graphkit train --config examples/config/case30_ieee_base.yaml --data_path examples/data
148
+ ```
149
+
150
+ ---
151
+
152
+ ## Fine-Tuning Models
153
+
154
+ ```bash
155
+ gridfm_graphkit finetune --config path/to/config.yaml --model_path path/to/model.pt
156
+ ```
157
+
158
+ ### Arguments
159
+
160
+ | Argument | Type | Description | Default |
161
+ | -------- | ---- | ----------- | ------- |
162
+ | `--config` | `str` | **Required**. Fine-tuning configuration file. | `None` |
163
+ | `--model_path` | `str` | **Required**. Path to a pre-trained model state dict. | `None` |
164
+ | `--exp_name` | `str` | MLflow experiment name. | `timestamp` |
165
+ | `--run_name` | `str` | MLflow run name. | `run` |
166
+ | `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
167
+ | `--data_path` | `str` | Root dataset directory. | `data` |
168
+ | `--compile [MODE]` | `str` | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If flag is passed without a value, mode is `default`. | `None` |
169
+ | `--bfloat16` | `flag` | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | `False` |
170
+ | `--tf32` | `flag` | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | `False` |
171
+ | `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` |
172
+ | `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | `[]` |
173
+ | `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | `None` |
174
+ | `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | `None` |
175
+ | `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | `None` |
176
+ | `--compute_dc_ac_metrics` | `flag` | Compute ground-truth AC/DC power balance metrics on the test split. | `False` |
177
+ | `--mp_context` | `str` | DataLoader multiprocessing start method (`spawn`, `fork`, `forkserver`). Defaults to PyTorch's automatic choice. On Linux, `spawn` is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. | `None` |
178
+
179
+
180
+ ---
181
+
182
+ ## Evaluating Models
183
+
184
+ ```bash
185
+ gridfm_graphkit evaluate --config path/to/eval.yaml --model_path path/to/model.pt
186
+ ```
187
+
188
+ ### Arguments
189
+
190
+ | Argument | Type | Description | Default |
191
+ | -------- | ---- | ----------- | ------- |
192
+ | `--config` | `str` | **Required**. Path to evaluation config. | `None` |
193
+ | `--model_path` | `str` | Path to the trained model state dict. | `None` |
194
+ | `--normalizer_stats` | `str` | Path to `normalizer_stats.pt` from a training run. Restores `fit_on_train` normalizers from saved statistics instead of re-fitting on current split. | `None` |
195
+ | `--exp_name` | `str` | MLflow experiment name. | `timestamp` |
196
+ | `--run_name` | `str` | MLflow run name. | `run` |
197
+ | `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
198
+ | `--data_path` | `str` | Dataset directory. | `data` |
199
+ | `--compile [MODE]` | `str` | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If flag is passed without a value, mode is `default`. | `None` |
200
+ | `--bfloat16` | `flag` | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | `False` |
201
+ | `--tf32` | `flag` | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | `False` |
202
+ | `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` |
203
+ | `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | `[]` |
204
+ | `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | `None` |
205
+ | `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | `None` |
206
+ | `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | `None` |
207
+ | `--compute_dc_ac_metrics` | `flag` | Compute ground-truth AC/DC power balance metrics on the test split. | `False` |
208
+ | `--save_output` | `flag` | Save predictions as `<grid_name>_predictions.parquet` under MLflow artifacts (`.../artifacts/test`). | `False` |
209
+ | `--mp_context` | `str` | DataLoader multiprocessing start method (`spawn`, `fork`, `forkserver`). Defaults to PyTorch's automatic choice. On Linux, `spawn` is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. | `None` |
210
+
211
+ ### Example with saved normalizer stats
212
+
213
+ When evaluating a model on a dataset, you can pass the normalizer statistics from the original training run to ensure the same normalization parameters are used:
214
+
215
+ ```bash
216
+ gridfm_graphkit evaluate \
217
+ --config examples/config/HGNS_PF_datakit_case118.yaml \
218
+ --model_path mlruns/<experiment_id>/<run_id>/artifacts/model/best_model_state_dict.pt \
219
+ --normalizer_stats mlruns/<experiment_id>/<run_id>/artifacts/stats/normalizer_stats.pt \
220
+ --data_path data
221
+ ```
222
+
223
+ > **Note:** The `--normalizer_stats` flag only affects normalizers with `fit_strategy = "fit_on_train"` (e.g. `HeteroDataMVANormalizer`). Per-sample normalizers (`HeteroDataPerSampleMVANormalizer`) always recompute their statistics from the current dataset regardless of this flag.
224
+
225
+ ---
226
+
227
+ ## Running Predictions
228
+
229
+ ```bash
230
+ gridfm_graphkit predict --config path/to/config.yaml --model_path path/to/model.pt
231
+ ```
232
+
233
+ ### Arguments
234
+
235
+ | Argument | Type | Description | Default |
236
+ | -------- | ---- | ----------- | ------- |
237
+ | `--config` | `str` | **Required**. Path to prediction config file. | `None` |
238
+ | `--model_path` | `str` | Path to trained model state dict. Optional; may be defined in config. | `None` |
239
+ | `--normalizer_stats` | `str` | Path to `normalizer_stats.pt` from a training run. Restores `fit_on_train` normalizers from saved statistics. | `None` |
240
+ | `--exp_name` | `str` | MLflow experiment name. | `timestamp` |
241
+ | `--run_name` | `str` | MLflow run name. | `run` |
242
+ | `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
243
+ | `--data_path` | `str` | Dataset directory. | `data` |
244
+ | `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` |
245
+ | `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | `[]` |
246
+ | `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | `None` |
247
+ | `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | `None` |
248
+ | `--output_path` | `str` | Directory where predictions are saved as `<grid_name>_predictions.parquet`. | `data` |
249
+ | `--compile [MODE]` | `str` | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If flag is passed without a value, mode is `default`. | `None` |
250
+ | `--bfloat16` | `flag` | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | `False` |
251
+ | `--tf32` | `flag` | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | `False` |
252
+ | `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | `None` |
253
+ | `--mp_context` | `str` | DataLoader multiprocessing start method (`spawn`, `fork`, `forkserver`). Defaults to PyTorch's automatic choice. On Linux, `spawn` is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. | `None` |
254
+
255
+ ---
256
+
257
+ ## Benchmarking Dataloader Throughput
258
+
259
+ ```bash
260
+ gridfm_graphkit benchmark --config path/to/config.yaml
261
+ ```
262
+
263
+ ### Arguments
264
+
265
+ | Argument | Type | Description | Default |
266
+ | -------- | ---- | ----------- | ------- |
267
+ | `--config` | `str` | **Required**. Path to configuration YAML file. | `None` |
268
+ | `--data_path` | `str` | Root dataset directory. | `data` |
269
+ | `--epochs` | `int` | Number of epochs to iterate through the train dataloader. | `3` |
270
+ | `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` |
271
+ | `--dataset_wrapper_cache_dir` | `str` | Directory for dataset wrapper disk cache. | `None` |
272
+ | `--num_workers` | `int` | Override `data.workers` from YAML. | `None` |
273
+ | `--plugins` | `list[str]` | Python packages to import for plugin registration. | `[]` |
274
+ | `--mp_context` | `str` | DataLoader multiprocessing start method (`spawn`, `fork`, `forkserver`). Defaults to PyTorch's automatic choice. On Linux, `spawn` is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. | `None` |
275
+
276
+ Use built-in help for full command details:
277
+
278
+ ```bash
279
+ gridfm_graphkit --help
280
+ gridfm_graphkit <command> --help
281
+ ```
282
+
283
+ ---
284
+
285
+ ## Running Tests
286
+
287
+ ### Unit and Integration Tests
288
+
289
+ Install the test dependencies first (if not already done):
290
+
291
+ ```bash
292
+ pip install -e ".[dev,test]"
293
+ ```
294
+
295
+ Run the full unit test suite:
296
+
297
+ ```bash
298
+ pytest ./tests
299
+ ```
300
+
301
+ Run the base set integration tests:
302
+
303
+ ```bash
304
+ pytest ./integrationtests/test_base_set.py
305
+ ```
306
+
307
+ ### Running Base Set Tests on an LSF Cluster (GPU)
308
+
309
+ To submit the base set integration tests as an interactive LSF job with GPU access, use `bsub`. Adjust the paths to match your environment:
310
+
311
+ ```bash
312
+ bsub -gpu "num=1" \
313
+ -n 16 \
314
+ -R "rusage[mem=32GB] span[hosts=1]" \
315
+ -Is \
316
+ -J gridfm_base_set_tests \
317
+ /bin/bash -c "
318
+ cd /path/to/gridfm-graphkit && \
319
+ export PATH=/path/to/cuda/bin:\$PATH \
320
+ CUDA_HOME=/path/to/cuda \
321
+ LD_LIBRARY_PATH=/path/to/cuda/lib64:\$LD_LIBRARY_PATH && \
322
+ source /path/to/venv/bin/activate && \
323
+ pytest ./integrationtests/test_base_set.py
324
+ "
325
+ ```
326
+
327
+ Key `bsub` options used above:
328
+
329
+ | Option | Description |
330
+ | ------ | ----------- |
331
+ | `-gpu "num=1"` | Request 1 GPU |
332
+ | `-n 16` | Request 16 CPU slots |
333
+ | `-R "rusage[mem=32GB] span[hosts=1]"` | Reserve 32 GB of memory on a single host |
334
+ | `-Is` | Run as an interactive shell session |
335
+ | `-J <job_name>` | Assign a name to the job |
336
+
337
+ **Concrete example** (adapt paths to your cluster setup):
338
+
339
+ ```bash
340
+ bsub -gpu "num=1" -n 16 -R "rusage[mem=32GB] span[hosts=1]" -Is -J hpo_trial_190 /bin/bash -c "cd /dccstor/terratorch/users/rkie/gitco/gridfm-graphkit && export PATH=/opt/share/cuda-12.8.1/bin:\$PATH CUDA_HOME=/opt/share/cuda-12.8.1 LD_LIBRARY_PATH=/opt/share/cuda-12.8.1/lib64:\$LD_LIBRARY_PATH && source /u/rkie/venvs/venv_gridfm-graphkit/bin/activate && pytest ./integrationtests/test_base_set.py"
341
+ ```
342
+
343
+ ---
@@ -0,0 +1,300 @@
1
+ <p align="center">
2
+ <img src="https://raw.githubusercontent.com/gridfm/gridfm-graphkit/refs/heads/main/docs/figs/KIT.png" alt="GridFM logo" style="width: 40%; height: auto;"/>
3
+ <br/>
4
+ </p>
5
+
6
+ <p align="center" style="font-size: 25px;">
7
+ <b>gridfm-graphkit</b>
8
+ </p>
9
+
10
+
11
+ [![DOI](https://zenodo.org/badge/1007159095.svg)](https://doi.org/10.5281/zenodo.17016737)
12
+ [![Docs](https://img.shields.io/badge/docs-available-brightgreen)](https://gridfm.github.io/gridfm-graphkit/)
13
+ ![Coverage](https://img.shields.io/badge/coverage-83%25-yellowgreen)
14
+ [![OpenSSF Best Practices](https://www.bestpractices.dev/projects/12802/badge)](https://www.bestpractices.dev/projects/12802)
15
+ [![OpenSSF Scorecard](https://api.scorecard.dev/projects/github.com/gridfm/gridfm-graphkit/badge)](https://scorecard.dev/viewer/?uri=github.com/gridfm/gridfm-graphkit)
16
+ ![Python](https://img.shields.io/badge/python-3.10%20%E2%80%93%203.12-blue)
17
+ ![License](https://img.shields.io/badge/license-Apache%202.0-blue)
18
+
19
+ This library, developed by the GridFM team, lets you train, fine-tune, and interact with a foundation model for the electric power grid.
20
+
21
+ ---
22
+
23
+ # Installation
24
+
25
+ Create and activate a virtual environment (use Python 3.10, 3.11, or 3.12; 3.12 is recommended):
26
+ ```bash
27
+ python -m venv venv
28
+ source venv/bin/activate
29
+ ```
30
+
31
+ Install gridfm-graphkit from PyPI
32
+ ```bash
33
+ pip install gridfm-graphkit
34
+ ```
35
+
36
+ **`torch-scatter` is a required dependency.** It cannot be bundled in `pyproject.toml` because the correct wheel depends on your PyTorch and CUDA versions, so it must be installed separately.
37
+
38
+ Get PyTorch + CUDA version for torch-scatter
39
+ ```bash
40
+ TORCH_CUDA_VERSION=$(python -c "import torch; base = torch.__version__.split('+')[0]; print(base + ('+cpu' if torch.version.cuda is None else '+cu' + torch.version.cuda.replace('.', '')))")
41
+ ```
42
+
43
+ Install the correct torch-scatter wheel
44
+ ```bash
45
+ pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH_CUDA_VERSION}.html
46
+ ```
47
+
48
+
49
+ For documentation generation and unit testing, install with the optional `dev` and `test` extras:
50
+
51
+ ```bash
52
+ pip install "gridfm-graphkit[dev,test]"
53
+ ```
54
+
55
+
56
+ # CLI commands
57
+
58
+ Use the command-line interface to train, fine-tune, evaluate, and run inference on GridFM models with YAML configs and MLflow tracking.
59
+
60
+ ```bash
61
+ gridfm_graphkit <command> [OPTIONS]
62
+ ```
63
+
64
+ Available commands:
65
+
66
+ * `train` - Train a new model from scratch
67
+ * `finetune` - Fine-tune an existing pre-trained model
68
+ * `evaluate` - Evaluate model performance on a dataset
69
+ * `predict` - Run inference and save predictions
+ * `benchmark` - Measure train dataloader throughput
70
+
71
+ ---
72
+
73
+ ## Training Models
74
+
75
+ ```bash
76
+ gridfm_graphkit train --config path/to/config.yaml
77
+ ```
78
+
79
+ ### Arguments
80
+
81
+ | Argument | Type | Description | Default |
82
+ | -------- | ---- | ----------- | ------- |
83
+ | `--config` | `str` | **Required**. Path to the training configuration YAML file. | `None` |
84
+ | `--exp_name` | `str` | MLflow experiment name. | `timestamp` |
85
+ | `--run_name` | `str` | MLflow run name. | `run` |
86
+ | `--log_dir` | `str` | MLflow tracking/logging directory. | `mlruns` |
87
+ | `--data_path` | `str` | Root dataset directory. | `data` |
88
+ | `--compile [MODE]` | `str` | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If flag is passed without a value, mode is `default`. | `None` |
89
+ | `--bfloat16` | `flag` | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | `False` |
90
+ | `--tf32` | `flag` | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | `False` |
91
+ | `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` |
92
+ | `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | `[]` |
93
+ | `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | `None` |
94
+ | `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | `None` |
95
+ | `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | `None` |
96
+ | `--compute_dc_ac_metrics` | `flag` | Compute ground-truth AC/DC power balance metrics on the test split. | `False` |
97
+ | `--mp_context` | `str` | DataLoader multiprocessing start method (`spawn`, `fork`, `forkserver`). Defaults to PyTorch's automatic choice. On Linux, `spawn` is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. | `None` |
98
+
99
+ ### Examples
100
+
101
+ **Standard Training:**
102
+
103
+ ```bash
104
+ gridfm_graphkit train --config examples/config/case30_ieee_base.yaml --data_path examples/data
105
+ ```
106
+
107
+ ---
108
+
109
+ ## Fine-Tuning Models
110
+
111
+ ```bash
112
+ gridfm_graphkit finetune --config path/to/config.yaml --model_path path/to/model.pt
113
+ ```
114
+
115
+ ### Arguments
116
+
117
+ | Argument | Type | Description | Default |
118
+ | -------- | ---- | ----------- | ------- |
119
+ | `--config` | `str` | **Required**. Fine-tuning configuration file. | `None` |
120
+ | `--model_path` | `str` | **Required**. Path to a pre-trained model state dict. | `None` |
121
+ | `--exp_name` | `str` | MLflow experiment name. | `timestamp` |
122
+ | `--run_name` | `str` | MLflow run name. | `run` |
123
+ | `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
124
+ | `--data_path` | `str` | Root dataset directory. | `data` |
125
+ | `--compile [MODE]` | `str` | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If flag is passed without a value, mode is `default`. | `None` |
126
+ | `--bfloat16` | `flag` | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | `False` |
127
+ | `--tf32` | `flag` | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | `False` |
128
+ | `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` |
129
+ | `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | `[]` |
130
+ | `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | `None` |
131
+ | `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | `None` |
132
+ | `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | `None` |
133
+ | `--compute_dc_ac_metrics` | `flag` | Compute ground-truth AC/DC power balance metrics on the test split. | `False` |
134
+ | `--mp_context` | `str` | DataLoader multiprocessing start method (`spawn`, `fork`, `forkserver`). Defaults to PyTorch's automatic choice. On Linux, `spawn` is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. | `None` |
135
+
136
+
137
+ ---
138
+
139
+ ## Evaluating Models
140
+
141
+ ```bash
142
+ gridfm_graphkit evaluate --config path/to/eval.yaml --model_path path/to/model.pt
143
+ ```
144
+
145
+ ### Arguments
146
+
147
+ | Argument | Type | Description | Default |
148
+ | -------- | ---- | ----------- | ------- |
149
+ | `--config` | `str` | **Required**. Path to evaluation config. | `None` |
150
+ | `--model_path` | `str` | Path to the trained model state dict. | `None` |
151
+ | `--normalizer_stats` | `str` | Path to `normalizer_stats.pt` from a training run. Restores `fit_on_train` normalizers from saved statistics instead of re-fitting them on the current split. | `None` |
152
+ | `--exp_name` | `str` | MLflow experiment name. | `timestamp` |
153
+ | `--run_name` | `str` | MLflow run name. | `run` |
154
+ | `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
155
+ | `--data_path` | `str` | Dataset directory. | `data` |
156
+ | `--compile [MODE]` | `str` | Enable `torch.compile` with the given mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If the flag is passed without a value, the mode is `default`. | `None` |
157
+ | `--bfloat16` | `flag` | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | `False` |
158
+ | `--tf32` | `flag` | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | `False` |
159
+ | `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` |
160
+ | `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | `[]` |
161
+ | `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | `None` |
162
+ | `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | `None` |
163
+ | `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | `None` |
164
+ | `--compute_dc_ac_metrics` | `flag` | Compute ground-truth AC/DC power balance metrics on the test split. | `False` |
165
+ | `--save_output` | `flag` | Save predictions as `<grid_name>_predictions.parquet` under MLflow artifacts (`.../artifacts/test`). | `False` |
166
+ | `--mp_context` | `str` | DataLoader multiprocessing start method (`spawn`, `fork`, `forkserver`). Defaults to PyTorch's automatic choice. On Linux, `spawn` is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. | `None` |
167
+
168
+ ### Example with saved normalizer stats
169
+
170
+ When evaluating a model, you can pass the normalizer statistics saved by the original training run so that the same normalization parameters are used:
171
+
172
+ ```bash
173
+ gridfm_graphkit evaluate \
174
+ --config examples/config/HGNS_PF_datakit_case118.yaml \
175
+ --model_path mlruns/<experiment_id>/<run_id>/artifacts/model/best_model_state_dict.pt \
176
+ --normalizer_stats mlruns/<experiment_id>/<run_id>/artifacts/stats/normalizer_stats.pt \
177
+ --data_path data
178
+ ```
179
+
180
+ > **Note:** The `--normalizer_stats` flag only affects normalizers with `fit_strategy = "fit_on_train"` (e.g. `HeteroDataMVANormalizer`). Per-sample normalizers (`HeteroDataPerSampleMVANormalizer`) always recompute their statistics from the current dataset regardless of this flag.
181
+
182
+ ---
183
+
184
+ ## Running Predictions
185
+
186
+ ```bash
187
+ gridfm_graphkit predict --config path/to/config.yaml --model_path path/to/model.pt
188
+ ```
189
+
190
+ ### Arguments
191
+
192
+ | Argument | Type | Description | Default |
193
+ | -------- | ---- | ----------- | ------- |
194
+ | `--config` | `str` | **Required**. Path to prediction config file. | `None` |
195
+ | `--model_path` | `str` | Path to the trained model state dict. Optional; can also be set in the config. | `None` |
196
+ | `--normalizer_stats` | `str` | Path to `normalizer_stats.pt` from a training run. Restores `fit_on_train` normalizers from saved statistics. | `None` |
197
+ | `--exp_name` | `str` | MLflow experiment name. | `timestamp` |
198
+ | `--run_name` | `str` | MLflow run name. | `run` |
199
+ | `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
200
+ | `--data_path` | `str` | Dataset directory. | `data` |
201
+ | `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` |
202
+ | `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | `[]` |
203
+ | `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | `None` |
204
+ | `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | `None` |
205
+ | `--output_path` | `str` | Directory where predictions are saved as `<grid_name>_predictions.parquet`. | `data` |
206
+ | `--compile [MODE]` | `str` | Enable `torch.compile` with the given mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If the flag is passed without a value, the mode is `default`. | `None` |
207
+ | `--bfloat16` | `flag` | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | `False` |
208
+ | `--tf32` | `flag` | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | `False` |
209
+ | `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | `None` |
210
+ | `--mp_context` | `str` | DataLoader multiprocessing start method (`spawn`, `fork`, `forkserver`). Defaults to PyTorch's automatic choice. On Linux, `spawn` is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. | `None` |
211
+
212
+ ---
213
+
214
+ ## Benchmarking Dataloader Throughput
215
+
216
+ ```bash
217
+ gridfm_graphkit benchmark --config path/to/config.yaml
218
+ ```
219
+
220
+ ### Arguments
221
+
222
+ | Argument | Type | Description | Default |
223
+ | -------- | ---- | ----------- | ------- |
224
+ | `--config` | `str` | **Required**. Path to configuration YAML file. | `None` |
225
+ | `--data_path` | `str` | Root dataset directory. | `data` |
226
+ | `--epochs` | `int` | Number of epochs to iterate through the train dataloader. | `3` |
227
+ | `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` |
228
+ | `--dataset_wrapper_cache_dir` | `str` | Directory for dataset wrapper disk cache. | `None` |
229
+ | `--num_workers` | `int` | Override `data.workers` from YAML. | `None` |
230
+ | `--plugins` | `list[str]` | Python packages to import for plugin registration. | `[]` |
231
+ | `--mp_context` | `str` | DataLoader multiprocessing start method (`spawn`, `fork`, `forkserver`). Defaults to PyTorch's automatic choice. On Linux, `spawn` is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. | `None` |
232
+
233
+ Use built-in help for full command details:
234
+
235
+ ```bash
236
+ gridfm_graphkit --help
237
+ gridfm_graphkit <command> --help
238
+ ```
239
+
240
+ ---
241
+
242
+ ## Running Tests
243
+
244
+ ### Unit and Integration Tests
245
+
246
+ Install the test dependencies first (if not already done):
247
+
248
+ ```bash
249
+ pip install -e ".[dev,test]"
250
+ ```
251
+
252
+ Run the full unit test suite:
253
+
254
+ ```bash
255
+ pytest ./tests
256
+ ```
257
+
258
+ Run the base-set integration tests:
259
+
260
+ ```bash
261
+ pytest ./integrationtests/test_base_set.py
262
+ ```
263
+
264
+ ### Running Base Set Tests on an LSF Cluster (GPU)
265
+
266
+ To submit the base set integration tests as an interactive LSF job with GPU access, use `bsub`. Adjust the paths to match your environment:
267
+
268
+ ```bash
269
+ bsub -gpu "num=1" \
270
+ -n 16 \
271
+ -R "rusage[mem=32GB] span[hosts=1]" \
272
+ -Is \
273
+ -J gridfm_base_set_tests \
274
+ /bin/bash -c "
275
+ cd /path/to/gridfm-graphkit && \
276
+ export PATH=/path/to/cuda/bin:\$PATH \
277
+ CUDA_HOME=/path/to/cuda \
278
+ LD_LIBRARY_PATH=/path/to/cuda/lib64:\$LD_LIBRARY_PATH && \
279
+ source /path/to/venv/bin/activate && \
280
+ pytest ./integrationtests/test_base_set.py
281
+ "
282
+ ```
283
+
284
+ Key `bsub` options used above:
285
+
286
+ | Option | Description |
287
+ | ------ | ----------- |
288
+ | `-gpu "num=1"` | Request 1 GPU |
289
+ | `-n 16` | Request 16 CPU slots |
290
+ | `-R "rusage[mem=32GB] span[hosts=1]"` | Reserve 32 GB of memory on a single host |
291
+ | `-Is` | Run as an interactive shell session |
292
+ | `-J <job_name>` | Assign a name to the job |
293
+
294
+ **Concrete example** (the paths below are specific to one cluster; adapt them to your setup):
295
+
296
+ ```bash
297
+ bsub -gpu "num=1" -n 16 -R "rusage[mem=32GB] span[hosts=1]" -Is -J hpo_trial_190 /bin/bash -c "cd /dccstor/terratorch/users/rkie/gitco/gridfm-graphkit && export PATH=/opt/share/cuda-12.8.1/bin:\$PATH CUDA_HOME=/opt/share/cuda-12.8.1 LD_LIBRARY_PATH=/opt/share/cuda-12.8.1/lib64:\$LD_LIBRARY_PATH && source /u/rkie/venvs/venv_gridfm-graphkit/bin/activate && pytest ./integrationtests/test_base_set.py"
298
+ ```
299
+
300
+ ---
@@ -0,0 +1,8 @@
1
+ import gridfm_graphkit.datasets
2
+ import gridfm_graphkit.tasks.base_task
3
+ import gridfm_graphkit.models.gnn_heterogeneous_gns
4
+ import gridfm_graphkit.tasks.reconstruction_tasks
5
+
6
+ # Public names exported by `from gridfm_graphkit import *`.
+ # The submodule imports above presumably exist for their import-time side
+ # effects (e.g. populating registries) — TODO confirm.
+ # NOTE(review): each `import gridfm_graphkit.<sub>` in this file also binds
+ # the name `gridfm_graphkit` in this package's own namespace (to the package
+ # itself), so this list makes a star-import re-export only the package
+ # object. Confirm whether the subpackages (e.g. "datasets", "models",
+ # "tasks") were meant to be listed instead.
+ __all__ = [
7
+ "gridfm_graphkit",
8
+ ]