morphml 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of morphml might be problematic. Click here for more details.

Files changed (158) hide show
  1. morphml/__init__.py +14 -0
  2. morphml/api/__init__.py +26 -0
  3. morphml/api/app.py +326 -0
  4. morphml/api/auth.py +193 -0
  5. morphml/api/client.py +338 -0
  6. morphml/api/models.py +132 -0
  7. morphml/api/rate_limit.py +192 -0
  8. morphml/benchmarking/__init__.py +36 -0
  9. morphml/benchmarking/comparison.py +430 -0
  10. morphml/benchmarks/__init__.py +56 -0
  11. morphml/benchmarks/comparator.py +409 -0
  12. morphml/benchmarks/datasets.py +280 -0
  13. morphml/benchmarks/metrics.py +199 -0
  14. morphml/benchmarks/openml_suite.py +201 -0
  15. morphml/benchmarks/problems.py +289 -0
  16. morphml/benchmarks/suite.py +318 -0
  17. morphml/cli/__init__.py +5 -0
  18. morphml/cli/commands/experiment.py +329 -0
  19. morphml/cli/main.py +457 -0
  20. morphml/cli/quickstart.py +312 -0
  21. morphml/config.py +278 -0
  22. morphml/constraints/__init__.py +19 -0
  23. morphml/constraints/handler.py +205 -0
  24. morphml/constraints/predicates.py +285 -0
  25. morphml/core/__init__.py +3 -0
  26. morphml/core/crossover.py +449 -0
  27. morphml/core/dsl/README.md +359 -0
  28. morphml/core/dsl/__init__.py +72 -0
  29. morphml/core/dsl/ast_nodes.py +364 -0
  30. morphml/core/dsl/compiler.py +318 -0
  31. morphml/core/dsl/layers.py +368 -0
  32. morphml/core/dsl/lexer.py +336 -0
  33. morphml/core/dsl/parser.py +455 -0
  34. morphml/core/dsl/search_space.py +386 -0
  35. morphml/core/dsl/syntax.py +199 -0
  36. morphml/core/dsl/type_system.py +361 -0
  37. morphml/core/dsl/validator.py +386 -0
  38. morphml/core/graph/__init__.py +40 -0
  39. morphml/core/graph/edge.py +124 -0
  40. morphml/core/graph/graph.py +507 -0
  41. morphml/core/graph/mutations.py +409 -0
  42. morphml/core/graph/node.py +196 -0
  43. morphml/core/graph/serialization.py +361 -0
  44. morphml/core/graph/visualization.py +431 -0
  45. morphml/core/objectives/__init__.py +20 -0
  46. morphml/core/search/__init__.py +33 -0
  47. morphml/core/search/individual.py +252 -0
  48. morphml/core/search/parameters.py +453 -0
  49. morphml/core/search/population.py +375 -0
  50. morphml/core/search/search_engine.py +340 -0
  51. morphml/distributed/__init__.py +76 -0
  52. morphml/distributed/fault_tolerance.py +497 -0
  53. morphml/distributed/health_monitor.py +348 -0
  54. morphml/distributed/master.py +709 -0
  55. morphml/distributed/proto/README.md +224 -0
  56. morphml/distributed/proto/__init__.py +74 -0
  57. morphml/distributed/proto/worker.proto +170 -0
  58. morphml/distributed/proto/worker_pb2.py +79 -0
  59. morphml/distributed/proto/worker_pb2_grpc.py +423 -0
  60. morphml/distributed/resource_manager.py +416 -0
  61. morphml/distributed/scheduler.py +567 -0
  62. morphml/distributed/storage/__init__.py +33 -0
  63. morphml/distributed/storage/artifacts.py +381 -0
  64. morphml/distributed/storage/cache.py +366 -0
  65. morphml/distributed/storage/checkpointing.py +329 -0
  66. morphml/distributed/storage/database.py +459 -0
  67. morphml/distributed/worker.py +549 -0
  68. morphml/evaluation/__init__.py +5 -0
  69. morphml/evaluation/heuristic.py +237 -0
  70. morphml/exceptions.py +55 -0
  71. morphml/execution/__init__.py +5 -0
  72. morphml/execution/local_executor.py +350 -0
  73. morphml/integrations/__init__.py +28 -0
  74. morphml/integrations/jax_adapter.py +206 -0
  75. morphml/integrations/pytorch_adapter.py +530 -0
  76. morphml/integrations/sklearn_adapter.py +206 -0
  77. morphml/integrations/tensorflow_adapter.py +230 -0
  78. morphml/logging_config.py +93 -0
  79. morphml/meta_learning/__init__.py +66 -0
  80. morphml/meta_learning/architecture_similarity.py +277 -0
  81. morphml/meta_learning/experiment_database.py +240 -0
  82. morphml/meta_learning/knowledge_base/__init__.py +19 -0
  83. morphml/meta_learning/knowledge_base/embedder.py +179 -0
  84. morphml/meta_learning/knowledge_base/knowledge_base.py +313 -0
  85. morphml/meta_learning/knowledge_base/meta_features.py +265 -0
  86. morphml/meta_learning/knowledge_base/vector_store.py +271 -0
  87. morphml/meta_learning/predictors/__init__.py +27 -0
  88. morphml/meta_learning/predictors/ensemble.py +221 -0
  89. morphml/meta_learning/predictors/gnn_predictor.py +552 -0
  90. morphml/meta_learning/predictors/learning_curve.py +231 -0
  91. morphml/meta_learning/predictors/proxy_metrics.py +261 -0
  92. morphml/meta_learning/strategy_evolution/__init__.py +27 -0
  93. morphml/meta_learning/strategy_evolution/adaptive_optimizer.py +226 -0
  94. morphml/meta_learning/strategy_evolution/bandit.py +276 -0
  95. morphml/meta_learning/strategy_evolution/portfolio.py +230 -0
  96. morphml/meta_learning/transfer.py +581 -0
  97. morphml/meta_learning/warm_start.py +286 -0
  98. morphml/optimizers/__init__.py +74 -0
  99. morphml/optimizers/adaptive_operators.py +399 -0
  100. morphml/optimizers/bayesian/__init__.py +52 -0
  101. morphml/optimizers/bayesian/acquisition.py +387 -0
  102. morphml/optimizers/bayesian/base.py +319 -0
  103. morphml/optimizers/bayesian/gaussian_process.py +635 -0
  104. morphml/optimizers/bayesian/smac.py +534 -0
  105. morphml/optimizers/bayesian/tpe.py +411 -0
  106. morphml/optimizers/differential_evolution.py +220 -0
  107. morphml/optimizers/evolutionary/__init__.py +61 -0
  108. morphml/optimizers/evolutionary/cma_es.py +416 -0
  109. morphml/optimizers/evolutionary/differential_evolution.py +556 -0
  110. morphml/optimizers/evolutionary/encoding.py +426 -0
  111. morphml/optimizers/evolutionary/particle_swarm.py +449 -0
  112. morphml/optimizers/genetic_algorithm.py +486 -0
  113. morphml/optimizers/gradient_based/__init__.py +22 -0
  114. morphml/optimizers/gradient_based/darts.py +550 -0
  115. morphml/optimizers/gradient_based/enas.py +585 -0
  116. morphml/optimizers/gradient_based/operations.py +474 -0
  117. morphml/optimizers/gradient_based/utils.py +601 -0
  118. morphml/optimizers/hill_climbing.py +169 -0
  119. morphml/optimizers/multi_objective/__init__.py +56 -0
  120. morphml/optimizers/multi_objective/indicators.py +504 -0
  121. morphml/optimizers/multi_objective/nsga2.py +647 -0
  122. morphml/optimizers/multi_objective/visualization.py +427 -0
  123. morphml/optimizers/nsga2.py +308 -0
  124. morphml/optimizers/random_search.py +172 -0
  125. morphml/optimizers/simulated_annealing.py +181 -0
  126. morphml/plugins/__init__.py +35 -0
  127. morphml/plugins/custom_evaluator_example.py +81 -0
  128. morphml/plugins/custom_optimizer_example.py +63 -0
  129. morphml/plugins/plugin_system.py +454 -0
  130. morphml/reports/__init__.py +30 -0
  131. morphml/reports/generator.py +362 -0
  132. morphml/tracking/__init__.py +7 -0
  133. morphml/tracking/experiment.py +309 -0
  134. morphml/tracking/logger.py +301 -0
  135. morphml/tracking/reporter.py +357 -0
  136. morphml/utils/__init__.py +6 -0
  137. morphml/utils/checkpoint.py +189 -0
  138. morphml/utils/comparison.py +390 -0
  139. morphml/utils/export.py +407 -0
  140. morphml/utils/progress.py +392 -0
  141. morphml/utils/validation.py +392 -0
  142. morphml/version.py +7 -0
  143. morphml/visualization/__init__.py +50 -0
  144. morphml/visualization/analytics.py +423 -0
  145. morphml/visualization/architecture_diagrams.py +353 -0
  146. morphml/visualization/architecture_plot.py +223 -0
  147. morphml/visualization/convergence_plot.py +174 -0
  148. morphml/visualization/crossover_viz.py +386 -0
  149. morphml/visualization/graph_viz.py +338 -0
  150. morphml/visualization/pareto_plot.py +149 -0
  151. morphml/visualization/plotly_dashboards.py +422 -0
  152. morphml/visualization/population.py +309 -0
  153. morphml/visualization/progress.py +260 -0
  154. morphml-1.0.0.dist-info/METADATA +434 -0
  155. morphml-1.0.0.dist-info/RECORD +158 -0
  156. morphml-1.0.0.dist-info/WHEEL +4 -0
  157. morphml-1.0.0.dist-info/entry_points.txt +3 -0
  158. morphml-1.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,552 @@
1
+ """GNN-based performance predictor using Graph Neural Networks.
2
+
3
+ Predicts architecture performance from graph structure without training.
4
+
5
+ Author: Eshan Roy <eshanized@proton.me>
6
+ Organization: TONMOY INFRASTRUCTURE & VISION
7
+ """
8
+
9
+ from typing import Any, Dict, List, Optional, Tuple
10
+
11
+ import numpy as np
12
+
13
+ from morphml.core.graph import ModelGraph
14
+ from morphml.logging_config import get_logger
15
+
16
+ logger = get_logger(__name__)
17
+
18
+
19
+ # Check for PyTorch dependencies
20
+ try:
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from torch_geometric.data import Batch, Data
25
+ from torch_geometric.nn import GATConv, global_max_pool, global_mean_pool
26
+
27
+ TORCH_AVAILABLE = True
28
+ except ImportError:
29
+ TORCH_AVAILABLE = False
30
+ logger.warning(
31
+ "PyTorch Geometric not available. GNNPredictor requires: "
32
+ "pip install torch torch-geometric"
33
+ )
34
+
35
+
36
+ if TORCH_AVAILABLE:
37
+
38
class ArchitectureGNN(nn.Module):
    """
    Graph Neural Network for architecture performance prediction.

    Architecture:
    - Graph Attention Network (GAT) layers, each followed by batch norm + ELU
    - Global pooling (mean + max, concatenated)
    - MLP predictor head ending in a sigmoid

    Input: node features, edge indices and per-node batch assignment
    Output: Predicted accuracy (0-1), one value per graph in the batch
    """

    def __init__(
        self,
        node_feature_dim: int = 128,
        hidden_dim: int = 256,
        num_layers: int = 4,
        num_heads: int = 4,
        dropout: float = 0.3,
    ):
        """
        Initialize GNN model.

        Args:
            node_feature_dim: Dimension of node features
            hidden_dim: Hidden dimension for GNN layers; must be a positive
                multiple of ``num_heads``
            num_layers: Number of GNN layers (the first layer is always built)
            num_heads: Number of attention heads (for GAT)
            dropout: Dropout rate

        Raises:
            ValueError: If ``num_heads`` is not positive or ``hidden_dim`` is
                not divisible by ``num_heads``.
        """
        super().__init__()

        # Each GAT layer emits num_heads * (hidden_dim // num_heads) features.
        # Without exact divisibility that width differs from hidden_dim and the
        # BatchNorm1d(hidden_dim) layers would only fail later, inside forward().
        if num_heads < 1:
            raise ValueError(f"num_heads must be >= 1, got {num_heads}")
        if hidden_dim % num_heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )

        self.node_feature_dim = node_feature_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Graph attention layers and their matching batch norms
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()

        head_dim = hidden_dim // num_heads
        # First layer consumes raw node features; the remaining
        # num_layers - 1 layers map hidden_dim -> hidden_dim.
        input_dims = [node_feature_dim] + [hidden_dim] * (num_layers - 1)
        for in_dim in input_dims:
            self.convs.append(
                GATConv(
                    in_dim,
                    head_dim,
                    heads=num_heads,
                    dropout=dropout,
                )
            )
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        # Predictor head (mean + max pooling = 2 * hidden_dim)
        self.predictor = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 1),
            nn.Sigmoid(),  # Output in [0, 1]
        )

    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor, batch: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Node features [num_nodes, node_feature_dim]
            edge_index: Edge indices [2, num_edges]
            batch: Batch assignment [num_nodes]

        Returns:
            Predicted accuracy [batch_size]
        """
        # Graph convolutions with residual connections
        for i, (conv, bn) in enumerate(zip(self.convs, self.batch_norms)):
            x_new = conv(x, edge_index)
            x_new = bn(x_new)
            x_new = F.elu(x_new)

            # Residual connection: skipped on the first layer and whenever
            # the input/output widths differ.
            if i > 0 and x.shape[1] == x_new.shape[1]:
                x = x + x_new
            else:
                x = x_new

        # Global pooling (combine mean and max)
        x_mean = global_mean_pool(x, batch)
        x_max = global_max_pool(x, batch)
        x = torch.cat([x_mean, x_max], dim=1)

        # Predict
        out = self.predictor(x)

        return out.squeeze(-1)
150
+
151
+ class GNNPredictor:
152
+ """
153
+ Train and use GNN for architecture performance prediction.
154
+
155
+ This predictor learns to estimate architecture performance from
156
+ graph structure, enabling fast evaluation without training.
157
+
158
+ Target: 75%+ prediction accuracy on held-out architectures
159
+ Speedup: 100-1000x faster than full training
160
+
161
+ Args:
162
+ config: Configuration dict
163
+ - node_feature_dim: Node feature dimension (default: 128)
164
+ - hidden_dim: Hidden dimension (default: 256)
165
+ - num_layers: Number of GNN layers (default: 4)
166
+ - num_heads: Attention heads (default: 4)
167
+ - dropout: Dropout rate (default: 0.3)
168
+ - lr: Learning rate (default: 1e-3)
169
+ - weight_decay: L2 regularization (default: 1e-5)
170
+
171
+ Example:
172
+ >>> # Collect training data from past experiments
173
+ >>> train_data = [
174
+ ... (graph1, 0.92), # (ModelGraph, accuracy)
175
+ ... (graph2, 0.88),
176
+ ... # ... more examples
177
+ ... ]
178
+ >>>
179
+ >>> # Train predictor
180
+ >>> predictor = GNNPredictor({'num_layers': 4})
181
+ >>> predictor.train(train_data, num_epochs=100)
182
+ >>>
183
+ >>> # Predict on new architecture
184
+ >>> predicted_acc = predictor.predict(new_graph)
185
+ >>> print(f"Predicted accuracy: {predicted_acc:.2%}")
186
+ """
187
+
188
def __init__(self, config: Optional[Dict[str, Any]] = None):
    """
    Initialize GNN predictor.

    Builds the ArchitectureGNN model, moves it to CUDA when available
    (CPU otherwise), and creates the AdamW optimizer and the
    ReduceLROnPlateau scheduler.

    Args:
        config: Optional configuration dict; missing keys fall back to
            the defaults used below (node_feature_dim=128, hidden_dim=256,
            num_layers=4, num_heads=4, dropout=0.3, lr=1e-3,
            weight_decay=1e-5).

    Raises:
        ImportError: If PyTorch / PyTorch Geometric could not be imported.
    """
    if not TORCH_AVAILABLE:
        raise ImportError(
            "GNNPredictor requires PyTorch and PyTorch Geometric. "
            "Install with: pip install torch torch-geometric"
        )

    self.config = config or {}

    # Model configuration
    self.model = ArchitectureGNN(
        node_feature_dim=self.config.get("node_feature_dim", 128),
        hidden_dim=self.config.get("hidden_dim", 256),
        num_layers=self.config.get("num_layers", 4),
        num_heads=self.config.get("num_heads", 4),
        dropout=self.config.get("dropout", 0.3),
    )

    # Device
    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    self.model = self.model.to(self.device)

    logger.info(f"GNNPredictor initialized on device: {self.device}")

    # Optimizer
    self.optimizer = torch.optim.AdamW(
        self.model.parameters(),
        lr=self.config.get("lr", 1e-3),
        weight_decay=self.config.get("weight_decay", 1e-5),
    )

    # Learning rate scheduler. The ``verbose`` flag is intentionally not
    # passed: it is deprecated in recent PyTorch releases, where supplying
    # it triggers a warning or is rejected outright.
    self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        self.optimizer, mode="min", factor=0.5, patience=10
    )

    # Training stats
    self.training_history: List[Dict[str, float]] = []
    self.is_trained = False
    # Weights of the best epoch seen during training; declared here so the
    # attribute exists before train() has ever run.
    self.best_model_state: Optional[Dict[str, Any]] = None
228
+
229
+ def train(
230
+ self,
231
+ train_data: List[Tuple[ModelGraph, float]],
232
+ val_data: Optional[List[Tuple[ModelGraph, float]]] = None,
233
+ num_epochs: int = 100,
234
+ batch_size: int = 32,
235
+ early_stopping_patience: int = 20,
236
+ ) -> Dict[str, Any]:
237
+ """
238
+ Train GNN predictor on historical data.
239
+
240
+ Args:
241
+ train_data: List of (architecture, accuracy) pairs
242
+ val_data: Optional validation data
243
+ num_epochs: Maximum training epochs
244
+ batch_size: Batch size
245
+ early_stopping_patience: Stop if no improvement for N epochs
246
+
247
+ Returns:
248
+ Training statistics dict
249
+ """
250
+ logger.info(f"Training GNN predictor on {len(train_data)} examples")
251
+
252
+ # Convert to PyTorch Geometric Data objects
253
+ train_dataset = [self._graph_to_pyg_data(g, acc) for g, acc in train_data]
254
+
255
+ if val_data:
256
+ val_dataset = [self._graph_to_pyg_data(g, acc) for g, acc in val_data]
257
+ else:
258
+ # Use 20% of training data for validation
259
+ split_idx = int(0.8 * len(train_dataset))
260
+ val_dataset = train_dataset[split_idx:]
261
+ train_dataset = train_dataset[:split_idx]
262
+
263
+ logger.info(f"Split: {len(train_dataset)} train, {len(val_dataset)} validation")
264
+
265
+ # Training loop
266
+ best_val_loss = float("inf")
267
+ patience_counter = 0
268
+
269
+ for epoch in range(num_epochs):
270
+ # Train
271
+ train_loss, train_mae = self._train_epoch(train_dataset, batch_size)
272
+
273
+ # Validate
274
+ val_loss, val_mae = self._validate(val_dataset, batch_size)
275
+
276
+ # Learning rate scheduling
277
+ self.scheduler.step(val_loss)
278
+
279
+ # Track history
280
+ self.training_history.append(
281
+ {
282
+ "epoch": epoch,
283
+ "train_loss": train_loss,
284
+ "train_mae": train_mae,
285
+ "val_loss": val_loss,
286
+ "val_mae": val_mae,
287
+ }
288
+ )
289
+
290
+ # Log progress
291
+ if epoch % 10 == 0 or epoch == num_epochs - 1:
292
+ logger.info(
293
+ f"Epoch {epoch:3d}/{num_epochs}: "
294
+ f"train_loss={train_loss:.4f}, train_mae={train_mae:.4f}, "
295
+ f"val_loss={val_loss:.4f}, val_mae={val_mae:.4f}"
296
+ )
297
+
298
+ # Early stopping
299
+ if val_loss < best_val_loss:
300
+ best_val_loss = val_loss
301
+ patience_counter = 0
302
+ # Save best model
303
+ self.best_model_state = self.model.state_dict()
304
+ else:
305
+ patience_counter += 1
306
+ if patience_counter >= early_stopping_patience:
307
+ logger.info(f"Early stopping at epoch {epoch}")
308
+ break
309
+
310
+ # Restore best model
311
+ self.model.load_state_dict(self.best_model_state)
312
+ self.is_trained = True
313
+
314
+ logger.info(f"Training complete. Best val loss: {best_val_loss:.4f}")
315
+
316
+ return {
317
+ "best_val_loss": best_val_loss,
318
+ "num_epochs": epoch + 1,
319
+ "history": self.training_history,
320
+ }
321
+
322
+ def _train_epoch(self, dataset: List[Data], batch_size: int) -> Tuple[float, float]:
323
+ """Train one epoch."""
324
+ self.model.train()
325
+
326
+ # Shuffle and batch
327
+ indices = torch.randperm(len(dataset))
328
+ total_loss = 0.0
329
+ total_mae = 0.0
330
+ num_batches = 0
331
+
332
+ for i in range(0, len(dataset), batch_size):
333
+ batch_indices = indices[i : i + batch_size]
334
+ batch_data = [dataset[idx] for idx in batch_indices]
335
+
336
+ batch = Batch.from_data_list(batch_data).to(self.device)
337
+
338
+ # Forward
339
+ pred = self.model(batch.x, batch.edge_index, batch.batch)
340
+ loss = F.mse_loss(pred, batch.y)
341
+ mae = F.l1_loss(pred, batch.y)
342
+
343
+ # Backward
344
+ self.optimizer.zero_grad()
345
+ loss.backward()
346
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
347
+ self.optimizer.step()
348
+
349
+ total_loss += loss.item()
350
+ total_mae += mae.item()
351
+ num_batches += 1
352
+
353
+ return total_loss / num_batches, total_mae / num_batches
354
+
355
+ def _validate(self, dataset: List[Data], batch_size: int) -> Tuple[float, float]:
356
+ """Validate on dataset."""
357
+ self.model.eval()
358
+
359
+ total_loss = 0.0
360
+ total_mae = 0.0
361
+ num_batches = 0
362
+
363
+ with torch.no_grad():
364
+ for i in range(0, len(dataset), batch_size):
365
+ batch_data = dataset[i : i + batch_size]
366
+ batch = Batch.from_data_list(batch_data).to(self.device)
367
+
368
+ pred = self.model(batch.x, batch.edge_index, batch.batch)
369
+ loss = F.mse_loss(pred, batch.y)
370
+ mae = F.l1_loss(pred, batch.y)
371
+
372
+ total_loss += loss.item()
373
+ total_mae += mae.item()
374
+ num_batches += 1
375
+
376
+ return total_loss / num_batches, total_mae / num_batches
377
+
378
def predict(self, graph: ModelGraph) -> float:
    """
    Estimate the accuracy of a single architecture.

    The graph is converted to a PyG sample (with a dummy label of 0.0),
    moved to the predictor's device and scored in eval mode without
    gradient tracking. A warning is logged if the predictor has not been
    trained yet.

    Args:
        graph: ModelGraph to evaluate

    Returns:
        Predicted accuracy (0-1)
    """
    if not self.is_trained:
        logger.warning("GNN predictor not trained, prediction may be inaccurate")

    self.model.eval()

    # The label is irrelevant for inference; 0.0 is only a placeholder.
    sample = self._graph_to_pyg_data(graph, 0.0)
    sample = sample.to(self.device)

    with torch.no_grad():
        score = self.model(sample.x, sample.edge_index, sample.batch)

    return float(score.item())
399
+
400
+ def _graph_to_pyg_data(self, graph: ModelGraph, accuracy: float) -> Data:
401
+ """
402
+ Convert ModelGraph to PyTorch Geometric Data.
403
+
404
+ Node features encode:
405
+ - Operation type (one-hot)
406
+ - Hyperparameters (normalized)
407
+ - Positional encoding (layer depth)
408
+ """
409
+ # Extract nodes and edges
410
+ node_list = list(graph.nodes.values())
411
+ node_to_idx = {node.id: i for i, node in enumerate(node_list)}
412
+
413
+ # Node features
414
+ node_features = []
415
+ for i, node in enumerate(node_list):
416
+ feat = self._encode_node(node, i, len(node_list))
417
+ node_features.append(feat)
418
+
419
+ x = torch.tensor(node_features, dtype=torch.float)
420
+
421
+ # Pad/truncate to fixed dimension
422
+ if x.shape[1] < self.config.get("node_feature_dim", 128):
423
+ padding = torch.zeros(
424
+ x.shape[0],
425
+ self.config.get("node_feature_dim", 128) - x.shape[1],
426
+ )
427
+ x = torch.cat([x, padding], dim=1)
428
+ elif x.shape[1] > self.config.get("node_feature_dim", 128):
429
+ x = x[:, : self.config.get("node_feature_dim", 128)]
430
+
431
+ # Edge index
432
+ edge_list = []
433
+ for edge in graph.edges.values():
434
+ source_idx = node_to_idx[edge.source_id]
435
+ target_idx = node_to_idx[edge.target_id]
436
+ edge_list.append([source_idx, target_idx])
437
+
438
+ if edge_list:
439
+ edge_index = torch.tensor(edge_list, dtype=torch.long).t()
440
+ else:
441
+ # Empty graph - create self-loops
442
+ edge_index = torch.tensor(
443
+ [[i, i] for i in range(len(node_list))], dtype=torch.long
444
+ ).t()
445
+
446
+ # Label
447
+ y = torch.tensor([accuracy], dtype=torch.float)
448
+
449
+ # Batch indicator (for single graph)
450
+ batch = torch.zeros(x.shape[0], dtype=torch.long)
451
+
452
+ return Data(x=x, edge_index=edge_index, y=y, batch=batch)
453
+
454
+ def _encode_node(self, node, position: int, total_nodes: int) -> List[float]:
455
+ """
456
+ Encode node as feature vector.
457
+
458
+ Features:
459
+ - One-hot operation type (20 dims)
460
+ - Hyperparameters (variable)
461
+ - Positional encoding (2 dims)
462
+ """
463
+ features = []
464
+
465
+ # Operation type (one-hot)
466
+ operation_types = [
467
+ "input",
468
+ "output",
469
+ "conv2d",
470
+ "conv1d",
471
+ "depthwise_conv",
472
+ "maxpool",
473
+ "avgpool",
474
+ "globalavgpool",
475
+ "dense",
476
+ "linear",
477
+ "relu",
478
+ "gelu",
479
+ "sigmoid",
480
+ "tanh",
481
+ "batchnorm",
482
+ "layernorm",
483
+ "dropout",
484
+ "residual",
485
+ "concat",
486
+ "add",
487
+ ]
488
+
489
+ op_encoding = [0.0] * len(operation_types)
490
+ if node.operation in operation_types:
491
+ op_encoding[operation_types.index(node.operation)] = 1.0
492
+
493
+ features.extend(op_encoding)
494
+
495
+ # Hyperparameters (normalized)
496
+ if hasattr(node, "params") and node.params:
497
+ # Conv layers
498
+ if "filters" in node.params:
499
+ features.append(min(node.params["filters"] / 512.0, 1.0))
500
+ if "kernel_size" in node.params:
501
+ features.append(node.params["kernel_size"] / 7.0)
502
+ if "stride" in node.params:
503
+ features.append(node.params["stride"] / 2.0)
504
+
505
+ # Dense layers
506
+ if "units" in node.params:
507
+ features.append(min(node.params["units"] / 2048.0, 1.0))
508
+
509
+ # Dropout
510
+ if "rate" in node.params:
511
+ features.append(node.params["rate"])
512
+
513
+ # Positional encoding (normalized depth)
514
+ features.append(position / max(total_nodes, 1))
515
+ features.append(np.sin(position * 2 * np.pi / max(total_nodes, 1)))
516
+
517
+ return features
518
+
519
def save(self, path: str) -> None:
    """Write the model weights, config, training history and trained flag to ``path``."""
    checkpoint = {
        "model_state": self.model.state_dict(),
        "config": self.config,
        "training_history": self.training_history,
        "is_trained": self.is_trained,
    }
    torch.save(checkpoint, path)
    logger.info(f"GNN predictor saved to {path}")
531
+
532
def load(self, path: str) -> None:
    """
    Restore the predictor from a checkpoint file.

    Reads the file onto the predictor's device, loads the weights into the
    existing model, then replaces ``config``, ``training_history`` and
    ``is_trained`` with the stored values.
    """
    # NOTE(review): torch.load may unpickle arbitrary objects depending on
    # the PyTorch version/defaults - only load checkpoints from trusted
    # sources.
    # NOTE(review): the model is not rebuilt from the stored config, so the
    # checkpoint presumably has to match the architecture this instance was
    # constructed with - confirm with callers.
    state = torch.load(path, map_location=self.device)
    self.model.load_state_dict(state["model_state"])
    for field in ("config", "training_history", "is_trained"):
        setattr(self, field, state[field])
    logger.info(f"GNN predictor loaded from {path}")
540
+
541
+
542
+ # Fallback if PyTorch not available
543
+ else:
544
+
545
class GNNPredictor:
    """Placeholder used when PyTorch / PyTorch Geometric are not importable."""

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Always raise ``ImportError`` with installation instructions."""
        message = (
            "GNNPredictor requires PyTorch and PyTorch Geometric. "
            "Install with: pip install torch torch-geometric"
        )
        raise ImportError(message)