TopoStateGrid 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,57 @@
1
+ """TopoStateGrid: physically informed graph construction for power-grid ML."""
2
+
3
+ from .builder import (
4
+ EDGE_FEATURE_NAMES,
5
+ NODE_FEATURE_NAMES,
6
+ build_graph,
7
+ build_graph_from_matpower,
8
+ build_graph_from_opfdata_json,
9
+ build_graphs_from_opfdata,
10
+ )
11
+ from .export import export_dataset, load_graphs, save_graphs, save_split_json, write_metadata_csv
12
+ from .labels import attach_labels, attach_stress_proxy_labels
13
+ from .normalizer import FeatureNormalizer
14
+ from .pandapower import build_graph_from_pandapower
15
+ from .parser import (
16
+ ParsedCase,
17
+ discover_opfdata_examples,
18
+ list_local_power_data,
19
+ parse_matpower_case,
20
+ parse_opfdata_sample,
21
+ )
22
+ from .splits import create_lono_split, create_random_split, create_time_based_split
23
+ from .tables import build_graph_from_csv_tables, build_graph_from_tables
24
+ from .temporal import make_temporal_windows
25
+ from .visualization import render_graph_sequence
26
+
27
# Package version reported at runtime.
__version__ = "1.1.0"

# Public API re-exported from the submodules imported above, in isort-style
# order (UPPER_CASE constants, then classes, then functions, each alphabetical).
__all__ = [
    "EDGE_FEATURE_NAMES",
    "NODE_FEATURE_NAMES",
    "FeatureNormalizer",
    "ParsedCase",
    "attach_labels",
    "attach_stress_proxy_labels",
    "build_graph",
    "build_graph_from_csv_tables",
    "build_graph_from_matpower",
    "build_graph_from_opfdata_json",
    "build_graph_from_pandapower",
    "build_graph_from_tables",
    "build_graphs_from_opfdata",
    "create_lono_split",
    "create_random_split",
    "create_time_based_split",
    "discover_opfdata_examples",
    "export_dataset",
    "list_local_power_data",
    "load_graphs",
    "make_temporal_windows",
    "parse_matpower_case",
    "parse_opfdata_sample",
    "render_graph_sequence",
    "save_graphs",
    "save_split_json",
    "write_metadata_csv",
]
@@ -0,0 +1,379 @@
1
+ """Homogeneous bus-branch graph construction."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ import numpy as np
10
+ import torch
11
+ from torch_geometric.data import Data
12
+
13
+ from .parser import ParsedCase, discover_opfdata_examples, parse_matpower_case, parse_opfdata_sample
14
+
15
# Column names of ``Data.x``: entry i names feature column i.
NODE_FEATURE_NAMES = (
    ["bus_status", "bus_type"]  # 0-1: bus flags
    + ["pd", "qd"]  # 2-3: demand aggregated at the bus
    + ["vm", "va"]  # 4-5: voltage magnitude / angle
    + ["vmax", "vmin"]  # 6-7: voltage magnitude limits
    + ["normalized_demand"]  # 8: pd scaled by its largest absolute value
)

# Column names of ``Data.edge_attr``. Reverse edges swap the per-end columns
# 3 <-> 4 and (6, 7) <-> (8, 9).
EDGE_FEATURE_NAMES = (
    ["component_type", "r", "x", "b_from", "b_to", "rate_a"]  # 0-5: branch parameters
    + ["pf", "qf", "pt", "qt"]  # 6-9: from/to-end flows (0.0 when no solved state)
    + ["loading_ratio", "outage_flag"]  # 10-11: derived indicators
)
41
+
42
+
43
def build_graph(parsed: ParsedCase, attach_proxy_label: bool = False) -> Data:
    """Convert a parsed case into a PyTorch Geometric ``Data`` graph.

    Dispatches on ``parsed.source_type`` ("opfdata" or "matpower").

    Args:
        parsed: Case produced by the parser module.
        attach_proxy_label: When True, call ``attach_stress_proxy_labels`` on the
            freshly built graph.

    Returns:
        The constructed graph.

    Raises:
        ValueError: If ``parsed.source_type`` is not a supported format.
    """
    kind = parsed.source_type
    if kind == "matpower":
        graph = _build_matpower_graph(parsed)
    elif kind == "opfdata":
        graph = _build_opfdata_graph(parsed)
    else:
        raise ValueError(f"Unsupported parsed case type: {kind}")

    if not attach_proxy_label:
        return graph

    # Imported here so the labels module is only loaded when proxy labels are requested.
    from .labels import attach_stress_proxy_labels

    attach_stress_proxy_labels(graph)
    return graph
58
+
59
+
60
def build_graph_from_opfdata_json(path: str | Path, attach_proxy_label: bool = True) -> Data:
    """Read one OPFData JSON scenario file and convert it into a graph.

    Unlike ``build_graph``, proxy stress labels are attached by default.
    """
    parsed_case = parse_opfdata_sample(path)
    return build_graph(parsed_case, attach_proxy_label=attach_proxy_label)
64
+
65
+
66
def build_graphs_from_opfdata(
    root: str | Path = "data/opfdata",
    network_id: str | None = None,
    limit: int | None = None,
    attach_proxy_label: bool = True,
) -> list[Data]:
    """Discover extracted OPFData JSON scenarios under ``root`` and build one graph each.

    ``network_id`` and ``limit`` are forwarded to ``discover_opfdata_examples``;
    graphs are returned in discovery order.
    """
    sample_paths = discover_opfdata_examples(root=root, network_id=network_id, limit=limit)
    graphs: list[Data] = []
    for sample_path in sample_paths:
        graphs.append(build_graph_from_opfdata_json(sample_path, attach_proxy_label=attach_proxy_label))
    return graphs
76
+
77
+
78
def build_graph_from_matpower(path: str | Path, attach_proxy_label: bool = False) -> Data:
    """Parse a MATPOWER/PGLib ``.m`` case file and build its static graph.

    Proxy labels are off by default because static cases carry no solved flows.
    """
    case = parse_matpower_case(path)
    return build_graph(case, attach_proxy_label=attach_proxy_label)
82
+
83
+
84
def _build_opfdata_graph(parsed: ParsedCase) -> Data:
    """Build a homogeneous bus graph from one OPFData scenario.

    Static bus columns come from ``grid.nodes.bus``; the solved voltage angle and
    magnitude come from ``solution.nodes.bus`` (zero-padded to one row per bus).
    Bus demand is the sum of load rows attached through ``grid.edges.load_link``.
    AC lines and transformers become bidirectional edges via ``_opfdata_edges``.

    Raises:
        ValueError: If the sample contains no bus rows.
    """
    grid = parsed.grid
    solution = parsed.solution
    grid_nodes = grid.get("nodes") or {}
    grid_edges = grid.get("edges") or {}
    solution_nodes = solution.get("nodes") or {}
    solution_edges = solution.get("edges") or {}

    bus_table = _as_2d_array(grid_nodes.get("bus"), width=4)
    bus_count = bus_table.shape[0]
    if bus_count == 0:
        raise ValueError(f"OPFData sample has no bus nodes: {parsed.path}")

    # NOTE(review): column 0 is used as "bus_status", but the published OPFData
    # bus layout appears to be [base_kv, bus_type, vmin, vmax] - confirm before
    # relying on this feature as an in-service flag.
    status_col = _column(bus_table, 0, default=1.0)
    type_col = _column(bus_table, 1, default=0.0)
    vmin_col = _column(bus_table, 2, default=0.0)
    vmax_col = _column(bus_table, 3, default=0.0)

    # Missing solution rows are zero-padded up to bus_count.
    solved = _as_2d_array(solution_nodes.get("bus"), width=2, rows=bus_count)
    va_col = _column(solved, 0, rows=bus_count)
    vm_col = _column(solved, 1, rows=bus_count)

    pd_col, qd_col = _aggregate_component_to_bus(
        grid_nodes.get("load"),
        grid_edges.get("load_link"),
        bus_count,
        feature_indices=(0, 1),
    )

    # Column order must match NODE_FEATURE_NAMES.
    node_matrix = np.column_stack(
        [
            status_col,
            type_col,
            pd_col,
            qd_col,
            vm_col,
            va_col,
            vmax_col,
            vmin_col,
            _normalized_abs(pd_col),
        ]
    )
    edge_index, edge_attr = _opfdata_edges(grid_edges, solution_edges, bus_count)

    return _make_data(
        x=node_matrix,
        edge_index=edge_index,
        edge_attr=edge_attr,
        parsed=parsed,
        notes="OPFData JSON scenario; load values and solved states are scenario-dependent.",
    )
122
+
123
+
124
def _build_matpower_graph(parsed: ParsedCase) -> Data:
    """Build a static bus-branch graph from a parsed MATPOWER/PGLib case.

    Node features follow ``NODE_FEATURE_NAMES``:

    - Bus demand is divided by ``parsed.base_mva`` (100.0 when it is unset or zero).
    - Voltage angles are converted from degrees to radians.
    - ``bus_status`` is fixed at 1.0.

    Every branch whose two endpoints are known bus ids becomes a pair of
    directed edges. Out-of-service branches (status <= 0) are kept with
    ``outage_flag = 1.0``. Flow and loading columns are 0.0 because a static
    case has no solved operating state.

    Raises:
        ValueError: If the case has no bus rows.
    """
    bus = _as_2d_array(parsed.grid.get("bus"), width=13)
    branch = _as_2d_array(parsed.grid.get("branch"), width=13)
    n_bus = bus.shape[0]
    if n_bus == 0:
        raise ValueError(f"MATPOWER case has no bus table: {parsed.path}")

    base = parsed.base_mva if parsed.base_mva else 100.0
    bus_ids = bus[:, 0].astype(int)
    bus_to_idx = {bus_id: idx for idx, bus_id in enumerate(bus_ids)}

    pd = _column(bus, 2, rows=n_bus) / base
    qd = _column(bus, 3, rows=n_bus) / base
    vm = _column(bus, 7, rows=n_bus)
    va = np.deg2rad(_column(bus, 8, rows=n_bus))
    vmax = _column(bus, 11, rows=n_bus)
    vmin = _column(bus, 12, rows=n_bus)
    demand_norm = _normalized_abs(pd)
    x = np.column_stack([np.ones(n_bus), _column(bus, 1, rows=n_bus), pd, qd, vm, va, vmax, vmin, demand_norm])

    edges: list[list[int]] = []
    attrs: list[list[float]] = []
    for row in branch:
        f_bus = int(row[0])
        t_bus = int(row[1])
        if f_bus not in bus_to_idx or t_bus not in bus_to_idx:
            continue  # branch references a bus id absent from the bus table
        r = _value(row, 2)
        x_val = _value(row, 3)
        # BR_B is the *total* line-charging susceptance. Split it evenly across
        # the two ends so b_from/b_to follow the same per-end convention as the
        # OPFData path (b_fr/b_to) instead of double-counting the charging.
        half_b = _value(row, 4) / 2.0
        rate_a = _value(row, 5)
        rate = rate_a / base if rate_a else 0.0  # RATE_A == 0 denotes "unlimited"
        ratio = _value(row, 8)
        shift = _value(row, 9)
        status = _value(row, 10, default=1.0)
        outage = 0.0 if status > 0 else 1.0
        # A tap ratio other than 0/1, or a non-zero phase shift, marks a transformer.
        component_type = 1.0 if ratio not in (0.0, 1.0) or shift != 0.0 else 0.0
        attr = [component_type, r, x_val, half_b, half_b, rate, 0.0, 0.0, 0.0, 0.0, 0.0, outage]
        _append_bidirectional(edges, attrs, bus_to_idx[f_bus], bus_to_idx[t_bus], attr)

    edge_index = np.asarray(edges, dtype=np.int64).T if edges else np.empty((2, 0), dtype=np.int64)
    edge_attr = np.asarray(attrs, dtype=np.float32) if attrs else np.empty((0, len(EDGE_FEATURE_NAMES)), dtype=np.float32)

    return _make_data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        parsed=parsed,
        notes="Static MATPOWER/PGLib case; solved operating-state flow fields are unavailable.",
    )
175
+
176
+
177
def _opfdata_edges(grid_edges: dict[str, Any], solution_edges: dict[str, Any], n_bus: int) -> tuple[np.ndarray, np.ndarray]:
    """Collect OPFData AC lines and transformers as bidirectional edges.

    For each kind, the ``senders``/``receivers`` lists give bus indices. Static
    and solved feature rows are matched to edges by position; a missing row
    becomes an empty list, and ``_value`` then defaults its entries to 0.0.
    Edges with an endpoint outside ``[0, n_bus)`` are skipped.

    Returns:
        ``(edge_index, edge_attr)`` with shapes ``(2, E)`` int64 and
        ``(E, len(EDGE_FEATURE_NAMES))`` float32, where E is twice the number of
        kept branches.
    """
    edges: list[list[int]] = []
    attrs: list[list[float]] = []

    for edge_kind in ("ac_line", "transformer"):
        static = grid_edges.get(edge_kind) or {}
        solution = solution_edges.get(edge_kind) or {}
        senders = static.get("senders") or []
        receivers = static.get("receivers") or []
        static_features = static.get("features") or []
        solution_features = solution.get("features") or []

        for idx, (src, dst) in enumerate(zip(senders, receivers)):
            # Reject negative indices too: they would otherwise become invalid
            # (negative) entries in edge_index.
            if not (0 <= src < n_bus and 0 <= dst < n_bus):
                continue
            static_feature = static_features[idx] if idx < len(static_features) else []
            solution_feature = solution_features[idx] if idx < len(solution_features) else []
            attr = _opfdata_edge_attr(edge_kind, static_feature, solution_feature)
            _append_bidirectional(edges, attrs, src, dst, attr)

    edge_index = np.asarray(edges, dtype=np.int64).T if edges else np.empty((2, 0), dtype=np.int64)
    edge_attr = np.asarray(attrs, dtype=np.float32) if attrs else np.empty((0, len(EDGE_FEATURE_NAMES)), dtype=np.float32)
    return edge_index, edge_attr
206
+
207
+
208
def _opfdata_edge_attr(edge_kind: str, static_feature: list[float], solution_feature: list[float]) -> list[float]:
    """Map one OPFData branch record onto the ``EDGE_FEATURE_NAMES`` layout.

    Static layouts read here are as follows; missing or non-finite entries
    default to 0.0 via ``_value``:

    - ``ac_line``: ``[angmin, angmax, b_fr, b_to, br_r, br_x, rate_a, ...]``
    - ``transformer``: ``[angmin, angmax, br_r, br_x, rate_a, rate_b, rate_c,
      tap, shift, b_fr, b_to]``

    Solved features are read as ``[pf, qf, pt, qt]``.
    """
    if edge_kind == "ac_line":
        component_type = 0.0
        b_from = _value(static_feature, 2)
        b_to = _value(static_feature, 3)
        r = _value(static_feature, 4)
        x_val = _value(static_feature, 5)
        rate = _value(static_feature, 6)
    else:
        component_type = 1.0
        r = _value(static_feature, 2)
        x_val = _value(static_feature, 3)
        rate = _value(static_feature, 4)
        # Transformer charging susceptances follow tap/shift in the OPFData
        # layout; records that are too short fall back to 0.0.
        b_from = _value(static_feature, 9)
        b_to = _value(static_feature, 10)

    # NOTE(review): confirm the from/to ordering of solved flows against the
    # dataset; the loading ratio is unaffected since it takes the max of both ends.
    pf = _value(solution_feature, 0)
    qf = _value(solution_feature, 1)
    pt = _value(solution_feature, 2)
    qt = _value(solution_feature, 3)
    loading = _loading_ratio(pf, qf, pt, qt, rate)
    return [component_type, r, x_val, b_from, b_to, rate, pf, qf, pt, qt, loading, 0.0]
230
+
231
+
232
def _aggregate_component_to_bus(
    features: Any,
    link: dict[str, Any] | None,
    n_bus: int,
    feature_indices: tuple[int, int],
) -> tuple[np.ndarray, np.ndarray]:
    """Sum two component feature columns onto the buses they are linked to.

    ``link`` holds ``senders``/``receivers`` lists pairing component rows with
    bus indices (either orientation; see ``_component_to_bus_pairs``).

    Returns:
        Two float arrays of length ``n_bus`` (one per entry of
        ``feature_indices``). Both are all-zero when there are no component
        rows or no link.
    """
    first_idx = feature_indices[0]
    second_idx = feature_indices[1]
    bus_a = np.zeros(n_bus, dtype=float)
    bus_b = np.zeros(n_bus, dtype=float)

    table = _as_2d_array(features, width=max(feature_indices) + 1)
    if table.size == 0 or not link:
        return bus_a, bus_b

    linked = _component_to_bus_pairs(
        link.get("senders") or [],
        link.get("receivers") or [],
        component_count=table.shape[0],
        bus_count=n_bus,
    )
    for component_idx, bus_idx in linked:
        component_row = table[component_idx]
        bus_a[bus_idx] += _value(component_row, first_idx)
        bus_b[bus_idx] += _value(component_row, second_idx)
    return bus_a, bus_b
254
+
255
+
256
def _component_to_bus_pairs(
    senders: list[int],
    receivers: list[int],
    component_count: int,
    bus_count: int,
) -> list[tuple[int, int]]:
    """Normalise link endpoints into ``(component_idx, bus_idx)`` pairs.

    A pair is read as sender=component, receiver=bus when both fit their
    ranges. Otherwise the swapped orientation is tried. Pairs fitting neither
    orientation are dropped.
    """

    def _in_range(value: int, upper: int) -> bool:
        return 0 <= value < upper

    matched: list[tuple[int, int]] = []
    for first, second in zip(senders, receivers):
        if _in_range(first, component_count) and _in_range(second, bus_count):
            matched.append((int(first), int(second)))
        elif _in_range(second, component_count) and _in_range(first, bus_count):
            matched.append((int(second), int(first)))
    return matched
269
+
270
+
271
def _make_data(
    x: np.ndarray,
    edge_index: np.ndarray,
    edge_attr: np.ndarray,
    parsed: ParsedCase,
    notes: str,
) -> Data:
    """Wrap feature arrays and case metadata into a ``torch_geometric`` ``Data``.

    Non-finite feature values (NaN and +/-inf) are replaced with 0.0. This
    matches how ``_value`` treats non-finite inputs.

    Label fields start as "missing" placeholders:

    - ``y`` and ``y_cls`` are ``[-1]``.
    - ``y_reg`` and ``risk_score`` are ``[0.0]``.
    - ``has_label`` is ``[False]`` and ``label_state`` is ``"missing"``.

    Case identifiers, feature names and JSON-encoded metadata are copied onto
    the graph.
    """

    def _finite(values: np.ndarray) -> np.ndarray:
        # nan_to_num's defaults map +/-inf to +/-float64-max, which overflows
        # back to +/-inf when cast to float32 below; zero them explicitly.
        return np.nan_to_num(values, nan=0.0, posinf=0.0, neginf=0.0)

    data = Data(
        x=torch.as_tensor(_finite(x), dtype=torch.float32),
        edge_index=torch.as_tensor(edge_index, dtype=torch.long),
        edge_attr=torch.as_tensor(_finite(edge_attr), dtype=torch.float32),
    )
    data.y = torch.tensor([-1], dtype=torch.long)
    data.y_cls = torch.tensor([-1], dtype=torch.long)
    data.y_reg = torch.tensor([0.0], dtype=torch.float32)
    data.risk_score = torch.tensor([0.0], dtype=torch.float32)
    data.has_label = torch.tensor([False], dtype=torch.bool)
    data.label_state = "missing"
    data.num_nodes = int(x.shape[0])
    data.network_id = parsed.network_id
    data.sample_id = parsed.sample_id
    data.timestamp = parsed.timestamp if parsed.timestamp is not None else ""
    data.scenario_id = parsed.scenario_id if parsed.scenario_id is not None else ""
    data.contingency_id = parsed.contingency_id if parsed.contingency_id is not None else ""
    # NOTE(review): float() raises if parsed.base_mva is None - confirm ParsedCase always sets it.
    data.base_mva = float(parsed.base_mva)
    data.source_type = parsed.source_type
    data.source_format = parsed.source_type
    data.source_path = parsed.path if parsed.path is not None else ""
    data.node_feature_names = list(NODE_FEATURE_NAMES)
    data.edge_feature_names = list(EDGE_FEATURE_NAMES)
    data.metadata_json = _metadata_to_json(parsed.metadata)
    data.metadata = data.metadata_json
    data.construction_notes = notes
    data.label_notes = ""
    return data
306
+
307
+
308
def _append_bidirectional(
    edges: list[list[int]],
    attrs: list[list[float]],
    src: int,
    dst: int,
    attr: list[float],
) -> None:
    """Append edge ``src -> dst`` with ``attr`` and ``dst -> src`` with an end-swapped copy.

    ``attr`` itself (not a copy) is stored for the forward edge. The reverse
    copy exchanges the per-end columns of ``EDGE_FEATURE_NAMES``:

    - ``b_from`` <-> ``b_to`` (3 <-> 4)
    - ``(pf, qf)`` <-> ``(pt, qt)`` (6-7 <-> 8-9)
    """
    edges.append([int(src), int(dst)])
    attrs.append(attr)

    flipped = attr.copy()
    flipped[3:5] = attr[4], attr[3]
    flipped[6:10] = attr[8], attr[9], attr[6], attr[7]

    edges.append([int(dst), int(src)])
    attrs.append(flipped)
322
+
323
+
324
def _loading_ratio(pf: float, qf: float, pt: float, qt: float, rate: float) -> float:
    """Return the larger end apparent power divided by ``rate``.

    Apparent power at each end is ``hypot(p, q)``. A non-positive ``rate`` is
    treated as "no rating" and yields 0.0.
    """
    if rate <= 0:
        return 0.0
    apparent = max(float(np.hypot(pf, qf)), float(np.hypot(pt, qt)))
    return apparent / rate
330
+
331
+
332
def _normalized_abs(values: np.ndarray) -> np.ndarray:
    """Scale ``values`` by their largest finite absolute value.

    Signs are preserved, so finite entries land in ``[-1, 1]``. Non-finite
    entries are ignored when computing the scale and pass through the division
    unchanged. Returns float zeros when there is no positive finite scale
    (empty, all-zero, or all-NaN input). Previously an all-NaN input produced a
    NaN scale (with a RuntimeWarning) and an all-NaN result.
    """
    magnitudes = np.abs(values[np.isfinite(values)])
    scale = float(magnitudes.max()) if magnitudes.size else 0.0
    if scale <= 0:
        return np.zeros_like(values, dtype=float)
    return values / scale
337
+
338
+
339
def _as_2d_array(values: Any, width: int, rows: int | None = None) -> np.ndarray:
    """Coerce ``values`` into a 2-D float array.

    - ``None`` or empty input gives zeros of shape ``(rows or 0, width)``,
      i.e. ``rows`` when given, else 0.
    - A scalar becomes a ``1x1`` array and a 1-D sequence becomes a single row.
      Scalars previously stayed 0-d and broke callers that read ``.shape[0]``.
    - When ``rows`` is given and the array has fewer rows, it is zero-padded to
      ``rows`` rows and at least ``width`` columns.
    - Arrays that already have enough rows are returned unchanged, even if
      narrower than ``width``.
    """
    empty_rows = 0 if rows is None else rows
    if values is None:
        return np.zeros((empty_rows, width), dtype=float)
    arr = np.asarray(values, dtype=float)
    if arr.size == 0:
        return np.zeros((empty_rows, width), dtype=float)
    if arr.ndim < 2:
        arr = arr.reshape(1, -1)
    if rows is not None and arr.shape[0] < rows:
        padded = np.zeros((rows, max(width, arr.shape[1])), dtype=float)
        padded[: arr.shape[0], : arr.shape[1]] = arr
        arr = padded
    return arr
354
+
355
+
356
def _column(arr: np.ndarray, idx: int, default: float = 0.0, rows: int | None = None) -> np.ndarray:
    """Return column ``idx`` of ``arr`` as a float array with exactly ``rows`` entries.

    ``rows`` defaults to ``arr.shape[0]``. Surplus rows are truncated and a
    short column is padded at the end with ``default``. An empty array or an
    out-of-range column index yields ``default`` everywhere.
    """
    target = arr.shape[0] if rows is None else rows
    if arr.size == 0 or idx >= arr.shape[1]:
        return np.full(target, default, dtype=float)

    column = arr[:target, idx].astype(float)
    shortfall = target - len(column)
    if shortfall <= 0:
        return column
    return np.concatenate([column, np.full(shortfall, default, dtype=float)])
366
+
367
+
368
def _value(row: Any, idx: int, default: float = 0.0) -> float:
    """Return ``float(row[idx])``, or ``default`` when that is not possible.

    ``default`` is returned when ``idx`` is past the end of ``row``, when the
    entry (or ``row`` itself) cannot be converted, or when the value is
    NaN/inf.
    """
    try:
        if not idx < len(row):
            return default
        number = float(row[idx])
    except (TypeError, ValueError):
        return default
    return number if np.isfinite(number) else default
376
+
377
+
378
def _metadata_to_json(metadata: dict[str, Any]) -> str:
    """Serialise ``metadata`` to JSON with sorted keys for stable output.

    Values JSON cannot encode natively are stringified with ``str``.
    """
    as_plain_dict = dict(metadata)
    return json.dumps(as_plain_dict, default=str, sort_keys=True)
@@ -0,0 +1,142 @@
1
+ """Disk export helpers for generated graph datasets."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import csv
6
+ import json
7
+ from pathlib import Path
8
+ from typing import Any, Sequence
9
+
10
+ import torch
11
+
12
+
13
def save_graphs(graphs: Sequence[Any], path: str | Path) -> Path:
    """Persist ``graphs`` as one list via ``torch.save`` and return the written path.

    Missing parent directories are created first.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    payload = list(graphs)
    torch.save(payload, target)
    return target
20
+
21
+
22
def load_graphs(path: str | Path) -> list[Any]:
    """Load graphs written by ``save_graphs``.

    Graph objects are arbitrary pickled Python objects, so ``weights_only=False``
    is requested explicitly; newer PyTorch releases default to ``True``.
    PyTorch versions that reject the keyword are retried without it.

    Security: this unpickles the file, which can execute arbitrary code. Only
    load files from trusted sources.
    """
    path = Path(path)
    try:
        return torch.load(path, weights_only=False)
    except TypeError as exc:
        # Retry only when the keyword itself was rejected. Any other TypeError
        # (e.g. raised while unpickling) is a real error; retrying under
        # different loading rules would hide it behind an unrelated failure.
        if "weights_only" not in str(exc):
            raise
        return torch.load(path)
30
+
31
+
32
def _matrix_width(value: Any) -> Any:
    """Return ``value.shape[1]`` for a 2-D array-like, otherwise ``""``."""
    if value is None or getattr(value, "ndim", None) != 2:
        return ""
    return value.shape[1]


def write_metadata_csv(graphs: Sequence[Any], path: str | Path) -> Path:
    """Write a compact per-graph metadata table to ``path`` as CSV and return it.

    One row is written per graph, with the columns listed in ``fields``.
    ``risk_score`` and ``y`` are filled only for graphs whose ``has_label``
    flag is set.

    Size columns are left empty when the tensor they are read from is missing,
    ``None`` or not 2-D. A graph built without ``x`` or ``edge_attr`` is
    therefore still written instead of raising AttributeError.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    fields = [
        "index",
        "network_id",
        "sample_id",
        "timestamp",
        "scenario_id",
        "contingency_id",
        "source_type",
        "num_nodes",
        "num_edges",
        "num_node_features",
        "num_edge_features",
        "risk_score",
        "y",
    ]
    with path.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=fields)
        writer.writeheader()
        for idx, graph in enumerate(graphs):
            has_label = _has_label(graph)
            x = getattr(graph, "x", None)
            edge_index = getattr(graph, "edge_index", None)
            edge_attr = getattr(graph, "edge_attr", None)
            # Fallback only used when the graph has no num_nodes attribute.
            fallback_nodes = x.shape[0] if getattr(x, "ndim", 0) >= 1 else ""
            writer.writerow(
                {
                    "index": idx,
                    "network_id": getattr(graph, "network_id", ""),
                    "sample_id": getattr(graph, "sample_id", ""),
                    "timestamp": getattr(graph, "timestamp", ""),
                    "scenario_id": getattr(graph, "scenario_id", ""),
                    "contingency_id": getattr(graph, "contingency_id", ""),
                    "source_type": getattr(graph, "source_type", ""),
                    "num_nodes": getattr(graph, "num_nodes", fallback_nodes),
                    "num_edges": _matrix_width(edge_index),
                    "num_node_features": _matrix_width(x),
                    "num_edge_features": _matrix_width(edge_attr),
                    "risk_score": _first_scalar(getattr(graph, "risk_score", "")) if has_label else "",
                    "y": _first_scalar(getattr(graph, "y", "")) if has_label else "",
                }
            )
    return path
77
+
78
+
79
def save_split_json(split: dict[str, Any], path: str | Path) -> Path:
    """Write split indices to ``path`` as JSON indented by two spaces; return the path.

    Missing parent directories are created first.
    """
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with open(destination, "w", encoding="utf-8") as stream:
        json.dump(split, stream, indent=2)
    return destination
87
+
88
+
89
def export_dataset(
    graphs: Sequence[Any],
    output_dir: str | Path = "outputs",
    split: dict[str, Any] | None = None,
) -> dict[str, Path]:
    """Export a dataset bundle into ``output_dir``.

    The bundle contains ``graphs.pt``, ``metadata.csv``, ``README_generated.md``
    and, when ``split`` is given, ``split.json``.

    Returns:
        Mapping of artifact name ("graphs", "metadata", "readme", optionally
        "split") to the written path.
    """
    out_root = Path(output_dir)
    out_root.mkdir(parents=True, exist_ok=True)

    written: dict[str, Path] = {}
    written["graphs"] = save_graphs(graphs, out_root / "graphs.pt")
    written["metadata"] = write_metadata_csv(graphs, out_root / "metadata.csv")
    written["readme"] = _write_generated_readme(graphs, out_root / "README_generated.md")
    if split is not None:
        written["split"] = save_split_json(split, out_root / "split.json")
    return written
106
+
107
+
108
def _write_generated_readme(graphs: Sequence[Any], path: Path) -> Path:
    """Write a short Markdown summary of the exported graphs to ``path``.

    Lists the graph count and the sorted distinct network ids. Graphs whose
    ``network_id`` is missing, ``None`` or empty are not listed. When no ids
    remain, "unknown" is printed instead of an empty or comma-led list.
    """
    labels: set[str] = set()
    for graph in graphs:
        network_id = getattr(graph, "network_id", None)
        if network_id is None:
            continue
        label = str(network_id)
        if label:
            labels.add(label)
    networks = sorted(labels)

    path.write_text(
        "\n".join(
            [
                "# TopoStateGrid Generated Dataset",
                "",
                f"Graphs: {len(graphs)}",
                f"Networks: {', '.join(networks) if networks else 'unknown'}",
                "",
                "Each graph is a PyTorch Geometric Data object with bus nodes, bidirectional branch edges,",
                "node features, edge features, preserved metadata, and optional proxy labels.",
                "",
            ]
        ),
        encoding="utf-8",
    )
    return path
126
+
127
+
128
def _first_scalar(value: Any) -> Any:
    """Reduce a tensor to its first element as a Python scalar.

    An empty tensor yields ``""``. Non-tensor values are returned unchanged.
    """
    if not isinstance(value, torch.Tensor):
        return value
    flat = value.detach().cpu().reshape(-1)
    return flat[0].item() if flat.numel() else ""
134
+
135
+
136
def _has_label(graph: Any) -> bool:
    """Interpret ``graph.has_label`` as a bool; a missing attribute counts as False.

    For a tensor flag the first element decides and an empty tensor is False.
    Other values use normal truthiness.
    """
    flag = getattr(graph, "has_label", False)
    if not isinstance(flag, torch.Tensor):
        return bool(flag)
    if flag.numel() == 0:
        return False
    return bool(flag.detach().cpu().reshape(-1)[0].item())
@@ -0,0 +1,116 @@
1
+ """Label attachment utilities."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ import torch
8
+ from torch_geometric.data import Data
9
+
10
+
11
def attach_labels(
    data: Data,
    y: Any | None = None,
    y_cls: Any | None = None,
    y_reg: Any | None = None,
    risk_score: Any | None = None,
) -> Data:
    """Attach caller-supplied labels to ``data`` in place and return it.

    Scalars become 1-element tensors. Dtypes are:

    - ``y``, ``y_reg`` and ``risk_score``: float32.
    - ``y_cls``: int64.

    When ``y`` is not given it falls back to ``y_cls``, then to ``y_reg``.
    Each supplied label sets ``has_label = [True]`` and
    ``label_state = "user"``.
    """

    def _mark_user_labelled() -> None:
        data.has_label = torch.tensor([True], dtype=torch.bool)
        data.label_state = "user"

    if y is not None:
        data.y = _tensor_1d(y)
        _mark_user_labelled()

    if y_cls is not None:
        data.y_cls = _tensor_1d(y_cls, dtype=torch.long)
        if y is None:
            data.y = data.y_cls
        _mark_user_labelled()

    if y_reg is not None:
        data.y_reg = _tensor_1d(y_reg)
        if y is None and y_cls is None:
            data.y = data.y_reg
        _mark_user_labelled()

    if risk_score is not None:
        data.risk_score = _tensor_1d(risk_score)
        _mark_user_labelled()

    return data
41
+
42
+
43
def attach_stress_proxy_labels(
    data: Data,
    threshold: float = 1.0,
    loading_feature: str = "loading_ratio",
    overwrite: bool = False,
) -> Data:
    """Attach temporary stress labels derived from max branch loading.

    This is a proxy label for prototyping graph construction. It is not a real
    cascading-failure or reliability target.

    The risk score is the largest finite value of the edge feature named
    ``loading_feature``. It is 0.0 when that feature, ``edge_attr`` or the
    edges themselves are absent. Fields written:

    - ``risk_score`` and ``y_reg``: the risk score.
    - ``y_cls`` and ``y``: 1 when the score exceeds ``threshold``, else 0.
    - ``has_label`` and ``label_state``: set to ``[True]`` and ``"proxy"``.

    Raises:
        ValueError: If real (non-placeholder) labels are already attached and
            ``overwrite`` is False. Pass ``overwrite=True`` to replace them.
    """
    existing = _existing_label_fields(data)
    if existing and not overwrite:
        raise ValueError(
            "Proxy label attachment would overwrite existing label fields "
            f"{existing}; pass overwrite=True to replace them."
        )

    feature_names = getattr(data, "edge_feature_names", None) or []
    edge_attr = getattr(data, "edge_attr", None)
    # A graph without edge features (edge_attr is None) has no loading to read.
    if edge_attr is None or loading_feature not in feature_names or edge_attr.numel() == 0:
        risk = 0.0
    else:
        values = edge_attr[:, feature_names.index(loading_feature)]
        finite_values = values[torch.isfinite(values)]
        risk = float(finite_values.max().item()) if finite_values.numel() else 0.0

    data.risk_score = torch.tensor([risk], dtype=torch.float32)
    data.y_reg = torch.tensor([risk], dtype=torch.float32)
    data.y_cls = torch.tensor([1 if risk > threshold else 0], dtype=torch.long)
    data.y = data.y_cls
    data.has_label = torch.tensor([True], dtype=torch.bool)
    data.label_state = "proxy"
    # Name the feature actually used so custom loading_feature values are documented correctly.
    data.label_notes = (
        f"Temporary proxy: y_cls = 1 if max {loading_feature} exceeds the threshold; "
        "not a cascading-failure ground-truth label."
    )
    return data
83
+
84
+
85
def _tensor_1d(value: Any, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Convert ``value`` to a tensor of ``dtype``, promoting a 0-d result to shape ``(1,)``.

    Inputs with one or more dimensions keep their shape.
    """
    converted = torch.as_tensor(value, dtype=dtype)
    if converted.dim() != 0:
        return converted
    return converted.unsqueeze(0)
90
+
91
+
92
def _existing_label_fields(data: Data) -> list[str]:
    """Return the label attributes on ``data`` that hold real values.

    Fields checked, in order: ``y``, ``y_cls``, ``y_reg``, ``risk_score``.
    Absent attributes, ``None`` values and placeholders (as judged by
    ``_is_missing_label_placeholder``) are skipped.
    """
    found: list[str] = []
    for field in ("y", "y_cls", "y_reg", "risk_score"):
        if not hasattr(data, field) or getattr(data, field) is None:
            continue
        if not _is_missing_label_placeholder(data, field):
            found.append(field)
    return found
100
+
101
+
102
def _flag_is_set(flag: Any) -> bool:
    """Interpret a ``has_label`` value (tensor or plain bool/int) as a bool.

    For a tensor the first element decides and an empty tensor is False.
    """
    if isinstance(flag, torch.Tensor):
        return flag.numel() > 0 and bool(flag.reshape(-1)[0].item())
    return bool(flag)


def _is_missing_label_placeholder(data: Data, name: str) -> bool:
    """Return True when ``data.<name>`` is a "no label yet" placeholder.

    Rules, checked in order:

    1. A graph without a ``has_label`` attribute is never treated as a
       placeholder.
    2. A set ``has_label`` (tensor or plain bool) means real labels exist.
    3. Otherwise the field is a placeholder when ``label_state == "missing"``.
    4. Otherwise it is a placeholder when it holds a single sentinel value:
       -1 or NaN for ``y``/``y_cls``, NaN for the other fields.

    The -1 check compares the exact value, so a float ``-1.5`` is not a
    sentinel, and a NaN ``y`` no longer raises from ``int(nan)``.
    """
    if not hasattr(data, "has_label"):
        return False
    if _flag_is_set(getattr(data, "has_label")):
        return False
    if getattr(data, "label_state", "") == "missing":
        return True
    value = getattr(data, name)
    if not isinstance(value, torch.Tensor) or value.numel() != 1:
        return False
    scalar = value.reshape(-1)[0]
    if bool(torch.isnan(scalar)):
        return True
    if name in {"y", "y_cls"}:
        return float(scalar.item()) == -1.0
    return False