gridfm-graphkit 0.0.2__tar.gz → 0.0.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/PKG-INFO +4 -2
  2. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/gridfm_graphkit.egg-info/PKG-INFO +4 -2
  3. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/gridfm_graphkit.egg-info/SOURCES.txt +1 -0
  4. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/pyproject.toml +4 -2
  5. gridfm_graphkit-0.0.3/tests/test_model_outputs.py +55 -0
  6. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/LICENSE +0 -0
  7. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/README.md +0 -0
  8. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/gridfm_graphkit/__init__.py +0 -0
  9. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/gridfm_graphkit/__main__.py +0 -0
  10. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/gridfm_graphkit/cli.py +0 -0
  11. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/gridfm_graphkit/datasets/__init__.py +0 -0
  12. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/gridfm_graphkit/datasets/data_normalization.py +0 -0
  13. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/gridfm_graphkit/datasets/globals.py +0 -0
  14. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/gridfm_graphkit/datasets/powergrid.py +0 -0
  15. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/gridfm_graphkit/datasets/transforms.py +0 -0
  16. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/gridfm_graphkit/datasets/utils.py +0 -0
  17. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/gridfm_graphkit/evaluation/__init__.py +0 -0
  18. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/gridfm_graphkit/evaluation/node_level.py +0 -0
  19. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/gridfm_graphkit/io/__init__.py +0 -0
  20. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/gridfm_graphkit/io/param_handler.py +0 -0
  21. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/gridfm_graphkit/models/__init__.py +0 -0
  22. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/gridfm_graphkit/models/gps_transformer.py +0 -0
  23. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/gridfm_graphkit/models/graphTransformer.py +0 -0
  24. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/gridfm_graphkit/training/__init__.py +0 -0
  25. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/gridfm_graphkit/training/callbacks.py +0 -0
  26. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/gridfm_graphkit/training/plugins.py +0 -0
  27. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/gridfm_graphkit/training/trainer.py +0 -0
  28. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/gridfm_graphkit/utils/__init__.py +0 -0
  29. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/gridfm_graphkit/utils/loss.py +0 -0
  30. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/gridfm_graphkit/utils/visualization.py +0 -0
  31. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/gridfm_graphkit.egg-info/dependency_links.txt +0 -0
  32. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/gridfm_graphkit.egg-info/entry_points.txt +0 -0
  33. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/gridfm_graphkit.egg-info/requires.txt +0 -0
  34. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/gridfm_graphkit.egg-info/top_level.txt +0 -0
  35. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/setup.cfg +0 -0
  36. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/tests/test_training.py +0 -0
  37. {gridfm_graphkit-0.0.2 → gridfm_graphkit-0.0.3}/tests/test_yaml_configs.py +0 -0
@@ -1,15 +1,17 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gridfm-graphkit
3
- Version: 0.0.2
3
+ Version: 0.0.3
4
4
  Summary: Grid Foundation Model
5
5
  Author-email: Matteo Mazzonelli <matteo.mazzonelli1@ibm.com>, Alban Puech <apuech@seas.harvard.edu>, Tamara Govindasamy <tamara.govindasamy@ibm.com>, Mangaliso Mngomezulu <mngomezulum@ibm.com>, Etienne Vos <etienne.vos@ibm.com>, Celia Cintas <celia.cintas@ibm.com>, Jonas Weiss <jwe@zurich.ibm.com>
6
6
  Maintainer-email: Matteo Mazzonelli <matteo.mazzonelli1@ibm.com>
7
7
  License-Expression: Apache-2.0
8
8
  Keywords: electric power grid,foundational model,graph neural networks
9
9
  Classifier: Development Status :: 2 - Pre-Alpha
10
+ Classifier: Programming Language :: Python :: 3.10
11
+ Classifier: Programming Language :: Python :: 3.11
10
12
  Classifier: Programming Language :: Python :: 3.12
11
13
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
12
- Requires-Python: >=3.12.10
14
+ Requires-Python: <3.13,>=3.10
13
15
  Description-Content-Type: text/markdown
14
16
  License-File: LICENSE
15
17
  Requires-Dist: mlflow>=3.1.0
@@ -1,15 +1,17 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gridfm-graphkit
3
- Version: 0.0.2
3
+ Version: 0.0.3
4
4
  Summary: Grid Foundation Model
5
5
  Author-email: Matteo Mazzonelli <matteo.mazzonelli1@ibm.com>, Alban Puech <apuech@seas.harvard.edu>, Tamara Govindasamy <tamara.govindasamy@ibm.com>, Mangaliso Mngomezulu <mngomezulum@ibm.com>, Etienne Vos <etienne.vos@ibm.com>, Celia Cintas <celia.cintas@ibm.com>, Jonas Weiss <jwe@zurich.ibm.com>
6
6
  Maintainer-email: Matteo Mazzonelli <matteo.mazzonelli1@ibm.com>
7
7
  License-Expression: Apache-2.0
8
8
  Keywords: electric power grid,foundational model,graph neural networks
9
9
  Classifier: Development Status :: 2 - Pre-Alpha
10
+ Classifier: Programming Language :: Python :: 3.10
11
+ Classifier: Programming Language :: Python :: 3.11
10
12
  Classifier: Programming Language :: Python :: 3.12
11
13
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
12
- Requires-Python: >=3.12.10
14
+ Requires-Python: <3.13,>=3.10
13
15
  Description-Content-Type: text/markdown
14
16
  License-File: LICENSE
15
17
  Requires-Dist: mlflow>=3.1.0
@@ -30,5 +30,6 @@ gridfm_graphkit/training/trainer.py
30
30
  gridfm_graphkit/utils/__init__.py
31
31
  gridfm_graphkit/utils/loss.py
32
32
  gridfm_graphkit/utils/visualization.py
33
+ tests/test_model_outputs.py
33
34
  tests/test_training.py
34
35
  tests/test_yaml_configs.py
@@ -9,10 +9,10 @@ namespaces = false
9
9
  [project]
10
10
  name = "gridfm-graphkit"
11
11
  description = "Grid Foundation Model"
12
- version = "0.0.2"
12
+ version = "0.0.3"
13
13
  readme = "README.md"
14
14
  license = "Apache-2.0"
15
- requires-python = ">=3.12.10"
15
+ requires-python = ">=3.10,<3.13"
16
16
 
17
17
  authors = [
18
18
  {name = "Matteo Mazzonelli", email = "matteo.mazzonelli1@ibm.com"},
@@ -32,6 +32,8 @@ keywords = ["electric power grid", "foundational model", "graph neural networks"
32
32
 
33
33
  classifiers = [
34
34
  "Development Status :: 2 - Pre-Alpha",
35
+ "Programming Language :: Python :: 3.10",
36
+ "Programming Language :: Python :: 3.11",
35
37
  "Programming Language :: Python :: 3.12",
36
38
  "Topic :: Scientific/Engineering :: Artificial Intelligence"
37
39
  ]
@@ -0,0 +1,55 @@
1
+ import torch
2
+ import numpy as np
3
+ import pytest
4
+
5
+ # Device setup
6
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
+
8
+ # Input shape config
9
+ num_nodes = 1
10
+ x_dim = 9
11
+ pe_dim = 20
12
+ edge_attr_dim = 2
13
+
14
+ # List of models and reference files to check
15
+ models_to_test = [
16
+ (
17
+ "v0_1_2",
18
+ "examples/models/GridFM_v0_1_2.pth",
19
+ "tests/data/reference_output_v0_1_2.npy",
20
+ ),
21
+ (
22
+ "v0_2_3",
23
+ "examples/models/GridFM_v0_2_3.pth",
24
+ "tests/data/reference_output_v0_2_3.npy",
25
+ ),
26
+ ]
27
+
28
+
29
+ @pytest.mark.parametrize("version, model_path, ref_output_path", models_to_test)
30
+ def test_model_matches_reference(version, model_path, ref_output_path):
31
+ torch.manual_seed(0)
32
+
33
+ # Prepare zero input
34
+ x = torch.zeros((num_nodes, x_dim), device=device)
35
+ pe = torch.zeros((num_nodes, pe_dim), device=device)
36
+ edge_index = torch.tensor([[0], [0]], device=device)
37
+ edge_attr = torch.zeros((1, edge_attr_dim), device=device)
38
+ batch = torch.zeros(num_nodes, dtype=torch.long, device=device)
39
+
40
+ # Load model
41
+ model = torch.load(model_path, weights_only=False, map_location=device).to(device)
42
+ model.eval()
43
+
44
+ # Get current output
45
+ with torch.no_grad():
46
+ output = model(x, pe, edge_index, edge_attr, batch).cpu().numpy()
47
+
48
+ # Load saved reference
49
+ reference = np.load(ref_output_path)
50
+
51
+ # Exact match assertion
52
+ assert np.allclose(output, reference, rtol=1e-5, atol=1e-6), (
53
+ f"Model output for {version} does not match reference within tolerance.\n"
54
+ f"Max absolute difference: {np.max(np.abs(output - reference))}"
55
+ )
File without changes