gridfm-graphkit 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,227 @@
1
+ from gridfm_graphkit.datasets.globals import PD, QD, PG, QG, VA
2
+ import torch
3
+ from abc import ABC, abstractmethod
4
+
5
+
6
class Normalizer(ABC):
    """
    Abstract base class for all normalization strategies.

    Subclasses must implement ``fit``, ``fit_from_dict``, ``transform`` and
    ``inverse_transform``. ``to`` is an optional hook for moving fitted
    parameters to a device; the default implementation does nothing, so
    normalizers without tensor state need not override it.
    """

    def to(self, device):
        """
        Move fitted parameters to ``device``.

        The base implementation is a no-op; normalizers that store tensor
        parameters override it.

        Args:
            device: Target device (anything accepted by ``torch.Tensor.to``).
        """

    @abstractmethod
    def fit(self, data: torch.Tensor) -> dict:
        """
        Fit normalization parameters from data.

        Args:
            data: Input tensor.

        Returns:
            Dictionary of computed parameters.
        """

    @abstractmethod
    def fit_from_dict(self, params: dict):
        """
        Set parameters from a precomputed dictionary.

        Args:
            params: Dictionary of parameters.
        """

    @abstractmethod
    def transform(self, data: torch.Tensor) -> torch.Tensor:
        """
        Normalize the input data.

        Args:
            data: Input tensor.

        Returns:
            Normalized tensor.
        """

    @abstractmethod
    def inverse_transform(self, normalized_data: torch.Tensor) -> torch.Tensor:
        """
        Undo normalization.

        Args:
            normalized_data: Normalized tensor.

        Returns:
            Original tensor.
        """
55
+
56
+
57
class MinMaxNormalizer(Normalizer):
    """
    Min-max scaler: maps each feature column onto [0, 1] using the
    column-wise minimum and maximum observed during fitting.

    Columns whose fitted range (max - min) is exactly zero are divided by 1
    instead, so they are only shifted by their minimum.
    """

    def __init__(self):
        # Column-wise extrema; populated by fit() or fit_from_dict().
        self.min_val = None
        self.max_val = None

    def to(self, device):
        """Move the fitted extrema to ``device``. Requires a prior fit."""
        self.min_val = self.min_val.to(device)
        self.max_val = self.max_val.to(device)

    def fit(self, data: torch.Tensor) -> dict:
        """
        Compute column-wise extrema of ``data`` (reduction over dim 0).

        Args:
            data: Input tensor.

        Returns:
            ``{"min_value": <min tensor>, "max_value": <max tensor>}``.
        """
        self.min_val = data.min(dim=0).values
        self.max_val = data.max(dim=0).values
        return {"min_value": self.min_val, "max_value": self.max_val}

    def fit_from_dict(self, params: dict):
        """
        Fill in any extremum that is still unset from ``params``.

        Values that are already fitted take precedence; a missing key leaves
        the attribute as ``None``.
        """
        for attr, key in (("min_val", "min_value"), ("max_val", "max_value")):
            if getattr(self, attr) is None:
                setattr(self, attr, params.get(key))

    def _range(self) -> torch.Tensor:
        """Return ``max_val - min_val`` with exact zeros replaced by 1."""
        span = self.max_val - self.min_val
        return span.masked_fill(span == 0, 1)

    def transform(self, data: torch.Tensor) -> torch.Tensor:
        """Scale ``data`` with the fitted extrema: ``(data - min) / range``."""
        if self.min_val is None or self.max_val is None:
            raise ValueError("fit must be called before transform.")
        span = self._range()
        return (data - self.min_val) / span

    def inverse_transform(self, normalized_data: torch.Tensor) -> torch.Tensor:
        """Map scaled data back: ``normalized * range + min``."""
        if self.min_val is None or self.max_val is None:
            raise ValueError("fit must be called before inverse_transform.")
        span = self._range()
        return (normalized_data * span) + self.min_val
97
+
98
+
99
class Standardizer(Normalizer):
    """
    Z-score normalizer: shifts each feature column by its fitted mean and
    scales it by its fitted standard deviation (``torch.std`` defaults).

    Columns whose fitted std is exactly zero are scaled by 1 instead.
    """

    def __init__(self):
        # Column-wise statistics; populated by fit() or fit_from_dict().
        self.mean = None
        self.std = None

    def to(self, device):
        """Move the fitted mean/std to ``device``. Requires a prior fit."""
        self.mean = self.mean.to(device)
        self.std = self.std.to(device)

    def fit(self, data: torch.Tensor) -> dict:
        """
        Compute column-wise mean and standard deviation (reduction over dim 0).

        Args:
            data: Input tensor.

        Returns:
            ``{"mean_value": <mean tensor>, "std_value": <std tensor>}``.
        """
        self.mean = data.mean(dim=0)
        self.std = data.std(dim=0)
        return {"mean_value": self.mean, "std_value": self.std}

    def fit_from_dict(self, params: dict):
        """
        Fill in any statistic that is still unset from ``params``.

        Values that are already fitted take precedence; a missing key leaves
        the attribute as ``None``.
        """
        for attr, key in (("mean", "mean_value"), ("std", "std_value")):
            if getattr(self, attr) is None:
                setattr(self, attr, params.get(key))

    def _divisor(self) -> torch.Tensor:
        """Return a copy of the fitted std with exact zeros replaced by 1."""
        return self.std.masked_fill(self.std == 0, 1)

    def transform(self, data: torch.Tensor) -> torch.Tensor:
        """Standardize ``data``: ``(data - mean) / std``."""
        if self.mean is None or self.std is None:
            raise ValueError("fit must be called before transform.")
        scale = self._divisor()
        return (data - self.mean) / scale

    def inverse_transform(self, normalized_data: torch.Tensor) -> torch.Tensor:
        """Undo standardization: ``normalized * std + mean``."""
        if self.mean is None or self.std is None:
            raise ValueError("fit must be called before inverse_transform.")
        scale = self._divisor()
        return (normalized_data * scale) + self.mean
139
+
140
+
141
class BaseMVANormalizer(Normalizer):
    """
    In power systems, a suitable normalization strategy must preserve the physical properties of
    the system. A known method is the conversion to the per-unit (p.u.) system, which expresses
    electrical quantities such as voltage, current, power, and impedance as fractions of predefined
    base values. These base values are usually chosen based on system parameters, such as rated
    voltage. The per-unit conversion ensures that power system equations remain scale-invariant,
    preserving fundamental physical relationships.

    Node data: columns PD, QD, PG, QG are divided by the fitted ``baseMVA`` and
    column VA is converted from degrees to radians. Edge data: every entry is
    multiplied by ``baseMVA_orig / baseMVA``.

    ``transform`` and ``inverse_transform`` return new tensors and leave their
    input untouched, matching the other normalizers in this module.
    """

    def __init__(self, node_data: bool, baseMVA_orig: float = 100.0):
        """
        Args:
            node_data: Whether data is node-level (columns PD, QD, PG, QG, VA are
                converted) or edge-level (all entries are rescaled).
            baseMVA_orig: Original baseMVA (e.g. from MATPOWER).
        """
        self.node_data = node_data
        self.baseMVA_orig = baseMVA_orig
        self.baseMVA = None

    def to(self, device):
        """No-op: kept for interface parity with the other normalizers."""
        pass

    def fit(self, data: torch.Tensor, baseMVA: float = None) -> dict:
        """
        Determine ``baseMVA``.

        Args:
            data: Input tensor.
            baseMVA: Used only for edge data; node data derives baseMVA as the
                maximum over the PD, QD, PG, QG columns.

        Returns:
            ``{"baseMVA_orig": ..., "baseMVA": ...}``.
        """
        if self.node_data:
            self.baseMVA = data[:, [PD, QD, PG, QG]].max()
        else:
            self.baseMVA = baseMVA

        return {"baseMVA_orig": self.baseMVA_orig, "baseMVA": self.baseMVA}

    def fit_from_dict(self, params: dict):
        """Set any still-unset base value from ``params``; fitted values win."""
        if self.baseMVA is None:
            self.baseMVA = params.get("baseMVA")
        if self.baseMVA_orig is None:
            self.baseMVA_orig = params.get("baseMVA_orig")

    def transform(self, data: torch.Tensor) -> torch.Tensor:
        """
        Convert ``data`` to per-unit values (see class docstring).

        Raises:
            ValueError: If ``baseMVA`` has not been determined.
            ZeroDivisionError: If ``baseMVA`` is 0.
        """
        if self.baseMVA is None:
            raise ValueError("BaseMVA is not specified")

        if self.baseMVA == 0:
            raise ZeroDivisionError("BaseMVA is 0.")

        if not self.node_data:
            return data * self.baseMVA_orig / self.baseMVA

        # Work on a copy so the caller's tensor is not modified in place.
        power_cols = [PD, QD, PG, QG]
        out = data.clone()
        out[:, power_cols] = data[:, power_cols] / self.baseMVA
        out[:, VA] = data[:, VA] * torch.pi / 180.0
        return out

    def inverse_transform(self, normalized_data: torch.Tensor) -> torch.Tensor:
        """
        Convert per-unit values back to the original units.

        Raises:
            ValueError: If ``baseMVA`` has not been determined.
        """
        if self.baseMVA is None:
            raise ValueError("fit must be called before inverse_transform.")

        if not self.node_data:
            return normalized_data * self.baseMVA / self.baseMVA_orig

        # Work on a copy: callers (e.g. on model predictions) must not see their
        # input mutated, and in-place writes would break autograd on leaf tensors.
        power_cols = [PD, QD, PG, QG]
        out = normalized_data.clone()
        out[:, power_cols] = normalized_data[:, power_cols] * self.baseMVA
        out[:, VA] = normalized_data[:, VA] * 180.0 / torch.pi
        return out
210
+
211
+
212
class IdentityNormalizer(Normalizer):
    """
    No normalization: returns data unchanged.

    Holds no parameters, so ``fit`` returns an empty dict and ``to`` /
    ``fit_from_dict`` are no-ops.
    """

    def to(self, device):
        """No-op; provided so callers can move any normalizer uniformly."""

    def fit(self, data: torch.Tensor) -> dict:
        """Nothing to fit; returns an empty parameter dict."""
        return {}

    def fit_from_dict(self, params: dict):
        """Nothing to load."""
        pass

    def transform(self, data: torch.Tensor) -> torch.Tensor:
        """Return ``data`` itself (same object, not a copy)."""
        return data

    def inverse_transform(self, normalized_data: torch.Tensor) -> torch.Tensor:
        """Return ``normalized_data`` itself (same object, not a copy)."""
        return normalized_data
@@ -0,0 +1,19 @@
1
# Global variables

# Node features indices: column positions in the node feature matrix.
# This order matches the column list ["Pd", "Qd", "Pg", "Qg", "Vm", "Va",
# "PQ", "PV", "REF"] used to build `x` in GridDatasetMem.process.
PD, QD, PG, QG, VM, VA, PQ, PV, REF = range(9)

# Edge features indices: column positions in the edge attribute matrix
# (built from the "G" and "B" CSV columns).
G, B = range(2)

# Name -> column index for the six physical node features.
FEATURES_IDX = {"PD": PD, "QD": QD, "PG": PG, "QG": QG, "VM": VM, "VA": VA}
# Names of the bus-type indicator columns.
BUS_TYPES = ["PQ", "PV", "REF"]
@@ -0,0 +1,192 @@
1
+ from gridfm_graphkit.datasets.data_normalization import Normalizer, BaseMVANormalizer
2
+ from gridfm_graphkit.datasets.transforms import (
3
+ AddEdgeWeights,
4
+ AddNormalizedRandomWalkPE,
5
+ )
6
+
7
+ import os.path as osp
8
+ import torch
9
+ from torch_geometric.data import Data, InMemoryDataset
10
+ import pandas as pd
11
+ from tqdm import tqdm
12
+ from typing import Optional, Callable
13
+
14
+
15
class GridDatasetMem(InMemoryDataset):
    """
    A PyTorch Geometric `InMemoryDataset` for power grid data stored in tabular CSV format.
    This dataset class reads node and edge data from CSV files, applies normalization using
    user-specified `Normalizer` instances, and builds graph data objects with edge weights and
    positional encodings.

    - Reads raw node and edge CSV files (`pf_node.csv`, `pf_edge.csv`) from `raw_dir`.
    - Applies the normalization method specified on both node and edge features.
    - Stores normalization statistics in the `processed` directory for reuse.
    - Constructs `torch_geometric.data.Data` objects with edge weights and positional
      encodings (via random walk embeddings).

    Args:
        root (str): Root directory where the dataset is stored.
        norm_method (str): Identifier for normalization method (e.g., "minmax", "standard").
        node_normalizer (Normalizer): Normalizer used for node features.
        edge_normalizer (Normalizer): Normalizer used for edge features.
        pe_dim (int): Length of the random walk used for positional encoding.
        mask_dim (int, optional): Number of features per-node that could be masked. Usually Pd, Qd, Pg, Qg, Vm, Va
        transform (callable, optional): Transformation applied at runtime.
        pre_transform (callable, optional): Transformation applied to each graph before saving to disk.
        pre_filter (callable, optional): Predicate applied to each graph before saving;
            graphs for which it returns False are dropped.
    """

    def __init__(
        self,
        root: str,
        norm_method: str,
        node_normalizer: Normalizer,
        edge_normalizer: Normalizer,
        pe_dim: int,
        mask_dim: int = 6,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
    ):
        self.norm_method = norm_method
        self.node_normalizer = node_normalizer
        self.edge_normalizer = edge_normalizer
        self.pe_dim = pe_dim
        self.mask_dim = mask_dim
        self.original_transform = None

        # May call `process()` when any processed file (data or stats) is missing.
        super().__init__(root, transform, pre_transform, pre_filter)

        node_stats_path = self._stats_path("node")
        edge_stats_path = self._stats_path("edge")
        if osp.exists(node_stats_path) and osp.exists(edge_stats_path):
            # NOTE(review): weights_only=False unpickles arbitrary objects; these
            # files are written by `process`, so only load from trusted roots.
            self.node_stats = torch.load(node_stats_path, weights_only=False)
            self.edge_stats = torch.load(edge_stats_path, weights_only=False)
            # No-op for parameters already fitted by `process` in this session.
            self.node_normalizer.fit_from_dict(self.node_stats)
            self.edge_normalizer.fit_from_dict(self.edge_stats)
        self.load(self.processed_paths[0])

    def _stats_path(self, kind: str) -> str:
        """Return the path of the saved ``kind`` ("node" or "edge") normalization stats."""
        return osp.join(self.processed_dir, f"{kind}_stats_{self.norm_method}.pt")

    @property
    def raw_file_names(self):
        # Raw CSV inputs expected in `raw_dir`; they are not downloaded.
        return ["pf_node.csv", "pf_edge.csv"]

    @property
    def processed_file_names(self):
        # The data file must stay first (`processed_paths[0]` is loaded).
        # The stats files are listed too so that `process` re-runs when they
        # are missing; otherwise the normalizers would never be fitted.
        return [
            f"data_full_{self.norm_method}.pt",
            f"node_stats_{self.norm_method}.pt",
            f"edge_stats_{self.norm_method}.pt",
        ]

    def download(self):
        pass

    def process(self):
        """Normalize the raw CSVs, save the fitted stats, and build one graph per scenario."""
        node_df = pd.read_csv(osp.join(self.raw_dir, "pf_node.csv"))
        edge_df = pd.read_csv(osp.join(self.raw_dir, "pf_edge.csv"))

        # Unique scenarios, in order of first appearance in the node table.
        scenarios = node_df["scenario"].unique()
        # Compare as sets. An element-wise array comparison fails when the
        # counts differ, and order is irrelevant because graphs are
        # assembled per scenario via `get_group`.
        if set(scenarios) != set(edge_df["scenario"].unique()):
            raise ValueError("Mismatch between node and edge scenario values.")

        # normalize node attributes
        node_cols = ["Pd", "Qd", "Pg", "Qg", "Vm", "Va"]
        node_values = torch.tensor(node_df[node_cols].values, dtype=torch.float)
        self.node_stats = self.node_normalizer.fit(node_values)
        node_df[node_cols] = self.node_normalizer.transform(node_values).numpy()

        # normalize edge attributes
        edge_cols = ["G", "B"]
        edge_values = torch.tensor(edge_df[edge_cols].values, dtype=torch.float)
        if isinstance(self.node_normalizer, BaseMVANormalizer) and isinstance(
            self.edge_normalizer,
            BaseMVANormalizer,
        ):
            # Edge per-unit conversion needs the baseMVA fitted on node data;
            # other edge normalizers do not accept that extra argument.
            self.edge_stats = self.edge_normalizer.fit(
                edge_values,
                self.node_normalizer.baseMVA,
            )
        else:
            self.edge_stats = self.edge_normalizer.fit(edge_values)
        edge_df[edge_cols] = self.edge_normalizer.transform(edge_values).numpy()

        # save stats
        torch.save(self.node_stats, self._stats_path("node"))
        torch.save(self.edge_stats, self._stats_path("edge"))

        # Create groupby objects for scenarios
        node_groups = node_df.groupby("scenario")
        edge_groups = edge_df.groupby("scenario")

        feature_cols = ["Pd", "Qd", "Pg", "Qg", "Vm", "Va", "PQ", "PV", "REF"]
        # These transforms keep only configuration state, so build them once.
        add_edge_weights = AddEdgeWeights()
        add_pe = AddNormalizedRandomWalkPE(walk_length=self.pe_dim, attr_name="pe")

        data_list = []
        for scenario_idx in tqdm(scenarios):
            # NODE DATA
            node_data = node_groups.get_group(scenario_idx)
            x = torch.tensor(node_data[feature_cols].values, dtype=torch.float)
            y = x[:, : self.mask_dim]

            # EDGE DATA
            edge_data = edge_groups.get_group(scenario_idx)
            edge_attr = torch.tensor(edge_data[["G", "B"]].values, dtype=torch.float)
            edge_index = torch.tensor(
                edge_data[["index1", "index2"]].values.T,
                dtype=torch.long,
            )

            # Create the Data object
            graph_data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
            graph_data = add_edge_weights(graph_data)
            graph_data = add_pe(graph_data)
            data_list.append(graph_data)

        # Honour the documented pre-processing hooks.
        if self.pre_filter is not None:
            data_list = [graph for graph in data_list if self.pre_filter(graph)]
        if self.pre_transform is not None:
            data_list = [self.pre_transform(graph) for graph in data_list]

        self.save(data_list, self.processed_paths[0])

    def change_transform(self, new_transform):
        """
        Temporarily switch to a new transform function, used when evaluating different tasks.

        Args:
            new_transform (Callable): The new transform to use.
        """
        self.original_transform = self.transform
        self.transform = new_transform

    def reset_transform(self):
        """
        Reverts the transform to the original one set during initialization, usually called after the evaluation step.

        Raises:
            ValueError: If no original transform has been recorded.
        """
        if self.original_transform is None:
            raise ValueError(
                "The original transform is None or the function change_transform needs to be called before",
            )
        self.transform = self.original_transform
@@ -0,0 +1,223 @@
1
+ from gridfm_graphkit.datasets.globals import PQ, PV, REF, PG, QG, VM, VA, G, B
2
+
3
+ import torch
4
+ from torch import Tensor
5
+ from torch_geometric.transforms import BaseTransform
6
+ from typing import Optional
7
+ import torch_geometric.typing
8
+ from torch_geometric.data import Data
9
+ from torch_geometric.utils import (
10
+ get_self_loop_attr,
11
+ is_torch_sparse_tensor,
12
+ to_edge_index,
13
+ to_torch_coo_tensor,
14
+ to_torch_csr_tensor,
15
+ )
16
+
17
+
18
class AddNormalizedRandomWalkPE(BaseTransform):
    r"""Adds the random walk positional encoding from the
    [Graph Neural Networks with Learnable Structural and Positional Representations](https://arxiv.org/abs/2110.07875)
    paper to the given graph. This is an adaptation from the original Pytorch Geometric implementation.

    The transition matrix is built from ``data.edge_weight`` (all ones when that
    attribute is absent) and each row is divided by its sum, clamped to at
    least 1e-8. The encoding of node ``i`` stacks the diagonal entries
    ``(P^k)[i, i]`` for ``k = 1 .. walk_length``.

    Args:
        walk_length (int): The number of random walk steps.
        attr_name (str, optional): The attribute name of the data object to add
            positional encodings to. If set to :obj:`None`, will be
            concatenated to :obj:`data.x`.
            (default: :obj:`"random_walk_pe"`)
    """

    def __init__(
        self,
        walk_length: int,
        attr_name: Optional[str] = "random_walk_pe",
    ) -> None:
        self.walk_length = walk_length
        self.attr_name = attr_name

    def forward(self, data: Data) -> Data:
        if data.edge_index is None:
            raise ValueError("Expected data.edge_index to be not None")
        row, col = data.edge_index
        N = data.num_nodes
        if N is None:
            raise ValueError("Expected data.num_nodes to be not None")

        edge_weight = getattr(data, "edge_weight", None)
        if edge_weight is None:
            # Unweighted graph: every edge contributes weight 1.
            edge_weight = torch.ones(row.numel(), device=row.device)

        loop_index = torch.arange(N, device=row.device)

        if N <= 2_000:  # Dense code path for faster computation:
            adj = torch.zeros((N, N), device=row.device)
            # Duplicate (row, col) pairs overwrite rather than accumulate here.
            adj[row, col] = edge_weight
            row_sums = adj.sum(dim=1, keepdim=True)  # Sum along rows
            row_sums = row_sums.clamp(min=1e-8)  # Prevent division by zero
            adj = adj / row_sums  # Normalize each row to sum to 1
        else:
            # Sparse tensors do not support the dense row-sum / broadcast
            # division above, so normalize the edge values before building
            # the matrix (duplicates are summed, matching sparse coalescing).
            row_sums = torch.zeros(
                N,
                dtype=edge_weight.dtype,
                device=edge_weight.device,
            )
            row_sums.scatter_add_(0, row, edge_weight)
            norm_weight = edge_weight / row_sums.clamp(min=1e-8)[row]
            if torch_geometric.typing.WITH_WINDOWS:
                adj = to_torch_coo_tensor(
                    data.edge_index,
                    norm_weight,
                    size=data.size(),
                )
            else:
                adj = to_torch_csr_tensor(
                    data.edge_index,
                    norm_weight,
                    size=data.size(),
                )

        def get_pe(out: Tensor) -> Tensor:
            # Diagonal of the k-step transition matrix.
            if is_torch_sparse_tensor(out):
                return get_self_loop_attr(*to_edge_index(out), num_nodes=N)
            return out[loop_index, loop_index]

        out = adj
        pe_list = [get_pe(out)]
        for _ in range(self.walk_length - 1):
            out = out @ adj
            pe_list.append(get_pe(out))

        pe = torch.stack(pe_list, dim=-1)

        if self.attr_name is None:
            # Documented behavior: append the encoding to the node features.
            x = getattr(data, "x", None)
            if x is None:
                data.x = pe
            else:
                x = x.view(-1, 1) if x.dim() == 1 else x
                data.x = torch.cat([x, pe.to(x.device, x.dtype)], dim=-1)
        else:
            data[self.attr_name] = pe

        return data
84
+
85
+
86
class AddEdgeWeights(BaseTransform):
    """
    Computes and adds edge weight as the magnitude of complex admittance.

    The magnitude ``sqrt(G**2 + B**2)`` is computed from the G and B columns of
    `data.edge_attr` and stored in `data.edge_weight`.
    """

    def forward(self, data):
        # PyG `Data` exposes `edge_attr` as a property that returns None when
        # unset, so `hasattr` alone would never catch a missing attribute.
        if getattr(data, "edge_attr", None) is None:
            raise AttributeError("Data must have 'edge_attr'.")

        # Extract real and imaginary parts of admittance
        real = data.edge_attr[:, G]
        imag = data.edge_attr[:, B]

        # Magnitude of the complex admittance, stored as the edge weight.
        data.edge_weight = torch.sqrt(real**2 + imag**2)

        return data
108
+
109
+
110
class AddIdentityMask(BaseTransform):
    """Creates an identity mask, and adds it as a `mask` attribute.

    The mask has the shape of `data.y` and every entry is False, so no masking
    is actually applied.
    """

    def forward(self, data):
        # PyG `Data.y` returns None when unset, so check the value, not hasattr.
        if getattr(data, "y", None) is None:
            raise AttributeError("Data must have ground truth 'y'.")

        # All-False mask with the same shape and device as the targets.
        data.mask = torch.zeros_like(data.y, dtype=torch.bool)

        return data
127
+
128
+
129
class AddRandomMask(BaseTransform):
    """Creates a random mask, and adds it as a `mask` attribute.

    The mask has shape ``(num_nodes, mask_dim)`` and each entry is `True` with
    probability `mask_ratio` and `False` otherwise.

    Args:
        mask_dim: Number of per-node feature columns covered by the mask.
        mask_ratio: Probability that an entry is masked.
    """

    def __init__(self, mask_dim, mask_ratio):
        super().__init__()
        self.mask_dim = mask_dim
        self.mask_ratio = mask_ratio

    def forward(self, data):
        # PyG `Data.x` returns None when unset, so check the value, not hasattr.
        if getattr(data, "x", None) is None:
            raise AttributeError("Data must have node features 'x'.")

        # Sample on the same device as the node features.
        mask = (
            torch.rand(data.x.size(0), self.mask_dim, device=data.x.device)
            < self.mask_ratio
        )

        data.mask = mask

        return data
152
+
153
+
154
class AddPFMask(BaseTransform):
    """Creates a mask according to the power flow problem and assigns it as a `mask` attribute.

    Masked (True) entries, by bus type read from the PQ/PV/REF indicator
    columns of `data.x`:

    - PQ buses: Vm, Va
    - PV buses: Qg, Va
    - REF buses: Pg, Qg
    """

    def forward(self, data):
        # PyG `Data` returns None for unset `x` / `y`, so check values, not hasattr.
        if getattr(data, "y", None) is None:
            raise AttributeError("Data must have ground truth 'y'.")

        if getattr(data, "x", None) is None:
            raise AttributeError("Data must have node features 'x'.")

        # Generate masks for each type of node
        mask_PQ = data.x[:, PQ] == 1  # PQ buses
        mask_PV = data.x[:, PV] == 1  # PV buses
        mask_REF = data.x[:, REF] == 1  # Reference buses

        # Initialize the mask tensor with False values
        mask = torch.zeros_like(data.y, dtype=torch.bool)

        mask[mask_PQ, VM] = True  # Mask Vm for PQ buses
        mask[mask_PQ, VA] = True  # Mask Va for PQ buses

        mask[mask_PV, QG] = True  # Mask Qg for PV buses
        mask[mask_PV, VA] = True  # Mask Va for PV buses

        mask[mask_REF, PG] = True  # Mask Pg for REF buses
        mask[mask_REF, QG] = True  # Mask Qg for REF buses

        # Attach the mask to the data object
        data.mask = mask

        return data
186
+
187
+
188
class AddOPFMask(BaseTransform):
    """Creates a mask according to the optimal power flow problem and assigns it as a `mask` attribute.

    Masked (True) entries, by bus type read from the PQ/PV/REF indicator
    columns of `data.x`:

    - PQ buses: Vm, Va
    - PV buses: Pg, Qg, Vm, Va
    - REF buses: Pg, Qg, Vm, Va
    """

    def forward(self, data):
        # PyG `Data` returns None for unset `x` / `y`, so check values, not hasattr.
        if getattr(data, "y", None) is None:
            raise AttributeError("Data must have ground truth 'y'.")

        if getattr(data, "x", None) is None:
            raise AttributeError("Data must have node features 'x'.")

        # Generate masks for each type of node
        mask_PQ = data.x[:, PQ] == 1  # PQ buses
        mask_PV = data.x[:, PV] == 1  # PV buses
        mask_REF = data.x[:, REF] == 1  # Reference buses

        # Initialize the mask tensor with False values
        mask = torch.zeros_like(data.y, dtype=torch.bool)

        mask[mask_PQ, VM] = True  # Mask Vm for PQ
        mask[mask_PQ, VA] = True  # Mask Va for PQ

        mask[mask_PV, PG] = True  # Mask Pg for PV
        mask[mask_PV, QG] = True  # Mask Qg for PV
        mask[mask_PV, VM] = True  # Mask Vm for PV
        mask[mask_PV, VA] = True  # Mask Va for PV

        mask[mask_REF, PG] = True  # Mask Pg for REF
        mask[mask_REF, QG] = True  # Mask Qg for REF
        mask[mask_REF, VM] = True  # Mask Vm for REF
        mask[mask_REF, VA] = True  # Mask Va for REF

        # Attach the mask to the data object
        data.mask = mask

        return data