morphml-1.0.0-py3-none-any.whl
This diff shows the content of a publicly released package version as it appears in its public registry. It is provided for informational purposes only.
Potentially problematic release: this version of morphml might be problematic.
- morphml/__init__.py +14 -0
- morphml/api/__init__.py +26 -0
- morphml/api/app.py +326 -0
- morphml/api/auth.py +193 -0
- morphml/api/client.py +338 -0
- morphml/api/models.py +132 -0
- morphml/api/rate_limit.py +192 -0
- morphml/benchmarking/__init__.py +36 -0
- morphml/benchmarking/comparison.py +430 -0
- morphml/benchmarks/__init__.py +56 -0
- morphml/benchmarks/comparator.py +409 -0
- morphml/benchmarks/datasets.py +280 -0
- morphml/benchmarks/metrics.py +199 -0
- morphml/benchmarks/openml_suite.py +201 -0
- morphml/benchmarks/problems.py +289 -0
- morphml/benchmarks/suite.py +318 -0
- morphml/cli/__init__.py +5 -0
- morphml/cli/commands/experiment.py +329 -0
- morphml/cli/main.py +457 -0
- morphml/cli/quickstart.py +312 -0
- morphml/config.py +278 -0
- morphml/constraints/__init__.py +19 -0
- morphml/constraints/handler.py +205 -0
- morphml/constraints/predicates.py +285 -0
- morphml/core/__init__.py +3 -0
- morphml/core/crossover.py +449 -0
- morphml/core/dsl/README.md +359 -0
- morphml/core/dsl/__init__.py +72 -0
- morphml/core/dsl/ast_nodes.py +364 -0
- morphml/core/dsl/compiler.py +318 -0
- morphml/core/dsl/layers.py +368 -0
- morphml/core/dsl/lexer.py +336 -0
- morphml/core/dsl/parser.py +455 -0
- morphml/core/dsl/search_space.py +386 -0
- morphml/core/dsl/syntax.py +199 -0
- morphml/core/dsl/type_system.py +361 -0
- morphml/core/dsl/validator.py +386 -0
- morphml/core/graph/__init__.py +40 -0
- morphml/core/graph/edge.py +124 -0
- morphml/core/graph/graph.py +507 -0
- morphml/core/graph/mutations.py +409 -0
- morphml/core/graph/node.py +196 -0
- morphml/core/graph/serialization.py +361 -0
- morphml/core/graph/visualization.py +431 -0
- morphml/core/objectives/__init__.py +20 -0
- morphml/core/search/__init__.py +33 -0
- morphml/core/search/individual.py +252 -0
- morphml/core/search/parameters.py +453 -0
- morphml/core/search/population.py +375 -0
- morphml/core/search/search_engine.py +340 -0
- morphml/distributed/__init__.py +76 -0
- morphml/distributed/fault_tolerance.py +497 -0
- morphml/distributed/health_monitor.py +348 -0
- morphml/distributed/master.py +709 -0
- morphml/distributed/proto/README.md +224 -0
- morphml/distributed/proto/__init__.py +74 -0
- morphml/distributed/proto/worker.proto +170 -0
- morphml/distributed/proto/worker_pb2.py +79 -0
- morphml/distributed/proto/worker_pb2_grpc.py +423 -0
- morphml/distributed/resource_manager.py +416 -0
- morphml/distributed/scheduler.py +567 -0
- morphml/distributed/storage/__init__.py +33 -0
- morphml/distributed/storage/artifacts.py +381 -0
- morphml/distributed/storage/cache.py +366 -0
- morphml/distributed/storage/checkpointing.py +329 -0
- morphml/distributed/storage/database.py +459 -0
- morphml/distributed/worker.py +549 -0
- morphml/evaluation/__init__.py +5 -0
- morphml/evaluation/heuristic.py +237 -0
- morphml/exceptions.py +55 -0
- morphml/execution/__init__.py +5 -0
- morphml/execution/local_executor.py +350 -0
- morphml/integrations/__init__.py +28 -0
- morphml/integrations/jax_adapter.py +206 -0
- morphml/integrations/pytorch_adapter.py +530 -0
- morphml/integrations/sklearn_adapter.py +206 -0
- morphml/integrations/tensorflow_adapter.py +230 -0
- morphml/logging_config.py +93 -0
- morphml/meta_learning/__init__.py +66 -0
- morphml/meta_learning/architecture_similarity.py +277 -0
- morphml/meta_learning/experiment_database.py +240 -0
- morphml/meta_learning/knowledge_base/__init__.py +19 -0
- morphml/meta_learning/knowledge_base/embedder.py +179 -0
- morphml/meta_learning/knowledge_base/knowledge_base.py +313 -0
- morphml/meta_learning/knowledge_base/meta_features.py +265 -0
- morphml/meta_learning/knowledge_base/vector_store.py +271 -0
- morphml/meta_learning/predictors/__init__.py +27 -0
- morphml/meta_learning/predictors/ensemble.py +221 -0
- morphml/meta_learning/predictors/gnn_predictor.py +552 -0
- morphml/meta_learning/predictors/learning_curve.py +231 -0
- morphml/meta_learning/predictors/proxy_metrics.py +261 -0
- morphml/meta_learning/strategy_evolution/__init__.py +27 -0
- morphml/meta_learning/strategy_evolution/adaptive_optimizer.py +226 -0
- morphml/meta_learning/strategy_evolution/bandit.py +276 -0
- morphml/meta_learning/strategy_evolution/portfolio.py +230 -0
- morphml/meta_learning/transfer.py +581 -0
- morphml/meta_learning/warm_start.py +286 -0
- morphml/optimizers/__init__.py +74 -0
- morphml/optimizers/adaptive_operators.py +399 -0
- morphml/optimizers/bayesian/__init__.py +52 -0
- morphml/optimizers/bayesian/acquisition.py +387 -0
- morphml/optimizers/bayesian/base.py +319 -0
- morphml/optimizers/bayesian/gaussian_process.py +635 -0
- morphml/optimizers/bayesian/smac.py +534 -0
- morphml/optimizers/bayesian/tpe.py +411 -0
- morphml/optimizers/differential_evolution.py +220 -0
- morphml/optimizers/evolutionary/__init__.py +61 -0
- morphml/optimizers/evolutionary/cma_es.py +416 -0
- morphml/optimizers/evolutionary/differential_evolution.py +556 -0
- morphml/optimizers/evolutionary/encoding.py +426 -0
- morphml/optimizers/evolutionary/particle_swarm.py +449 -0
- morphml/optimizers/genetic_algorithm.py +486 -0
- morphml/optimizers/gradient_based/__init__.py +22 -0
- morphml/optimizers/gradient_based/darts.py +550 -0
- morphml/optimizers/gradient_based/enas.py +585 -0
- morphml/optimizers/gradient_based/operations.py +474 -0
- morphml/optimizers/gradient_based/utils.py +601 -0
- morphml/optimizers/hill_climbing.py +169 -0
- morphml/optimizers/multi_objective/__init__.py +56 -0
- morphml/optimizers/multi_objective/indicators.py +504 -0
- morphml/optimizers/multi_objective/nsga2.py +647 -0
- morphml/optimizers/multi_objective/visualization.py +427 -0
- morphml/optimizers/nsga2.py +308 -0
- morphml/optimizers/random_search.py +172 -0
- morphml/optimizers/simulated_annealing.py +181 -0
- morphml/plugins/__init__.py +35 -0
- morphml/plugins/custom_evaluator_example.py +81 -0
- morphml/plugins/custom_optimizer_example.py +63 -0
- morphml/plugins/plugin_system.py +454 -0
- morphml/reports/__init__.py +30 -0
- morphml/reports/generator.py +362 -0
- morphml/tracking/__init__.py +7 -0
- morphml/tracking/experiment.py +309 -0
- morphml/tracking/logger.py +301 -0
- morphml/tracking/reporter.py +357 -0
- morphml/utils/__init__.py +6 -0
- morphml/utils/checkpoint.py +189 -0
- morphml/utils/comparison.py +390 -0
- morphml/utils/export.py +407 -0
- morphml/utils/progress.py +392 -0
- morphml/utils/validation.py +392 -0
- morphml/version.py +7 -0
- morphml/visualization/__init__.py +50 -0
- morphml/visualization/analytics.py +423 -0
- morphml/visualization/architecture_diagrams.py +353 -0
- morphml/visualization/architecture_plot.py +223 -0
- morphml/visualization/convergence_plot.py +174 -0
- morphml/visualization/crossover_viz.py +386 -0
- morphml/visualization/graph_viz.py +338 -0
- morphml/visualization/pareto_plot.py +149 -0
- morphml/visualization/plotly_dashboards.py +422 -0
- morphml/visualization/population.py +309 -0
- morphml/visualization/progress.py +260 -0
- morphml-1.0.0.dist-info/METADATA +434 -0
- morphml-1.0.0.dist-info/RECORD +158 -0
- morphml-1.0.0.dist-info/WHEEL +4 -0
- morphml-1.0.0.dist-info/entry_points.txt +3 -0
- morphml-1.0.0.dist-info/licenses/LICENSE +21 -0
morphml/meta_learning/predictors/gnn_predictor.py
@@ -0,0 +1,552 @@
"""GNN-based performance predictor using Graph Neural Networks.

Predicts architecture performance from graph structure, without training the
candidate network.

Author: Eshan Roy <eshanized@proton.me>
Organization: TONMOY INFRASTRUCTURE & VISION
"""

from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from morphml.core.graph import ModelGraph
from morphml.logging_config import get_logger

logger = get_logger(__name__)


# Check for PyTorch dependencies
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch_geometric.data import Batch, Data
    from torch_geometric.nn import GATConv, global_max_pool, global_mean_pool

    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    logger.warning(
        "PyTorch Geometric not available. GNNPredictor requires: "
        "pip install torch torch-geometric"
    )

if TORCH_AVAILABLE:

    class ArchitectureGNN(nn.Module):
        """
        Graph Neural Network for architecture performance prediction.

        Architecture:
        - Graph Attention Network (GAT) for node embeddings
        - Global pooling (mean + max)
        - MLP predictor head

        Input: ModelGraph
        Output: Predicted accuracy (0-1)
        """

        def __init__(
            self,
            node_feature_dim: int = 128,
            hidden_dim: int = 256,
            num_layers: int = 4,
            num_heads: int = 4,
            dropout: float = 0.3,
        ):
            """
            Initialize GNN model.

            Args:
                node_feature_dim: Dimension of node features
                hidden_dim: Hidden dimension for GNN layers
                num_layers: Number of GNN layers
                num_heads: Number of attention heads (for GAT)
                dropout: Dropout rate
            """
            super().__init__()

            self.node_feature_dim = node_feature_dim
            self.hidden_dim = hidden_dim
            self.num_layers = num_layers

            # Graph attention layers
            self.convs = nn.ModuleList()
            self.batch_norms = nn.ModuleList()

            # First layer
            self.convs.append(
                GATConv(
                    node_feature_dim,
                    hidden_dim // num_heads,
                    heads=num_heads,
                    dropout=dropout,
                )
            )
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

            # Hidden layers
            for _ in range(num_layers - 1):
                self.convs.append(
                    GATConv(
                        hidden_dim,
                        hidden_dim // num_heads,
                        heads=num_heads,
                        dropout=dropout,
                    )
                )
                self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

            # Predictor head (mean + max pooling = 2 * hidden_dim)
            self.predictor = nn.Sequential(
                nn.Linear(2 * hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.BatchNorm1d(hidden_dim),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, 64),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(64, 1),
                nn.Sigmoid(),  # Output in [0, 1]
            )

        def forward(
            self, x: torch.Tensor, edge_index: torch.Tensor, batch: torch.Tensor
        ) -> torch.Tensor:
            """
            Forward pass.

            Args:
                x: Node features [num_nodes, node_feature_dim]
                edge_index: Edge indices [2, num_edges]
                batch: Batch assignment [num_nodes]

            Returns:
                Predicted accuracy [batch_size]
            """
            # Graph convolutions with residual connections
            for i, (conv, bn) in enumerate(zip(self.convs, self.batch_norms)):
                x_new = conv(x, edge_index)
                x_new = bn(x_new)
                x_new = F.elu(x_new)

                # Residual connection (if dimensions match)
                if i > 0 and x.shape[1] == x_new.shape[1]:
                    x = x + x_new
                else:
                    x = x_new

            # Global pooling (combine mean and max)
            x_mean = global_mean_pool(x, batch)
            x_max = global_max_pool(x, batch)
            x = torch.cat([x_mean, x_max], dim=1)

            # Predict
            out = self.predictor(x)

            return out.squeeze(-1)

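    # Note: GATConv is used with its default concat=True, so each layer emits
    # (hidden_dim // num_heads) * num_heads = hidden_dim features; the batch
    # norms and the residual check above rely on that width. This assumes
    # hidden_dim is divisible by num_heads (true for the 256/4 defaults).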
    class GNNPredictor:
        """
        Train and use GNN for architecture performance prediction.

        This predictor learns to estimate architecture performance from
        graph structure, enabling fast evaluation without training.

        Target: 75%+ prediction accuracy on held-out architectures
        Speedup: 100-1000x faster than full training

        Args:
            config: Configuration dict
                - node_feature_dim: Node feature dimension (default: 128)
                - hidden_dim: Hidden dimension (default: 256)
                - num_layers: Number of GNN layers (default: 4)
                - num_heads: Attention heads (default: 4)
                - dropout: Dropout rate (default: 0.3)
                - lr: Learning rate (default: 1e-3)
                - weight_decay: L2 regularization (default: 1e-5)

        Example:
            >>> # Collect training data from past experiments
            >>> train_data = [
            ...     (graph1, 0.92),  # (ModelGraph, accuracy)
            ...     (graph2, 0.88),
            ...     # ... more examples
            ... ]
            >>>
            >>> # Train predictor
            >>> predictor = GNNPredictor({'num_layers': 4})
            >>> predictor.train(train_data, num_epochs=100)
            >>>
            >>> # Predict on new architecture
            >>> predicted_acc = predictor.predict(new_graph)
            >>> print(f"Predicted accuracy: {predicted_acc:.2%}")
        """

        def __init__(self, config: Optional[Dict[str, Any]] = None):
            """Initialize GNN predictor."""
            if not TORCH_AVAILABLE:
                raise ImportError(
                    "GNNPredictor requires PyTorch and PyTorch Geometric. "
                    "Install with: pip install torch torch-geometric"
                )

            self.config = config or {}

            # Model configuration
            self.model = ArchitectureGNN(
                node_feature_dim=self.config.get("node_feature_dim", 128),
                hidden_dim=self.config.get("hidden_dim", 256),
                num_layers=self.config.get("num_layers", 4),
                num_heads=self.config.get("num_heads", 4),
                dropout=self.config.get("dropout", 0.3),
            )

            # Device
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.model = self.model.to(self.device)

            logger.info(f"GNNPredictor initialized on device: {self.device}")

            # Optimizer
            self.optimizer = torch.optim.AdamW(
                self.model.parameters(),
                lr=self.config.get("lr", 1e-3),
                weight_decay=self.config.get("weight_decay", 1e-5),
            )

            # Learning rate scheduler (note: the `verbose` flag is deprecated
            # in recent PyTorch releases and may need to be dropped there)
            self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, mode="min", factor=0.5, patience=10, verbose=True
            )

            # Training stats
            self.training_history: List[Dict[str, float]] = []
            self.is_trained = False
            self.best_model_state: Optional[Dict[str, torch.Tensor]] = None

        def train(
            self,
            train_data: List[Tuple[ModelGraph, float]],
            val_data: Optional[List[Tuple[ModelGraph, float]]] = None,
            num_epochs: int = 100,
            batch_size: int = 32,
            early_stopping_patience: int = 20,
        ) -> Dict[str, Any]:
            """
            Train GNN predictor on historical data.

            Args:
                train_data: List of (architecture, accuracy) pairs
                val_data: Optional validation data
                num_epochs: Maximum training epochs
                batch_size: Batch size
                early_stopping_patience: Stop if no improvement for N epochs

            Returns:
                Training statistics dict
            """
            logger.info(f"Training GNN predictor on {len(train_data)} examples")

            # Convert to PyTorch Geometric Data objects
            train_dataset = [self._graph_to_pyg_data(g, acc) for g, acc in train_data]

            if val_data:
                val_dataset = [self._graph_to_pyg_data(g, acc) for g, acc in val_data]
            else:
                # Hold out the last 20% of the training data for validation.
                # The split is positional, so shuffle train_data beforehand if
                # it is ordered.
                split_idx = int(0.8 * len(train_dataset))
                val_dataset = train_dataset[split_idx:]
                train_dataset = train_dataset[:split_idx]

            logger.info(f"Split: {len(train_dataset)} train, {len(val_dataset)} validation")

            # Training loop
            best_val_loss = float("inf")
            patience_counter = 0

            for epoch in range(num_epochs):
                # Train
                train_loss, train_mae = self._train_epoch(train_dataset, batch_size)

                # Validate
                val_loss, val_mae = self._validate(val_dataset, batch_size)

                # Learning rate scheduling
                self.scheduler.step(val_loss)

                # Track history
                self.training_history.append(
                    {
                        "epoch": epoch,
                        "train_loss": train_loss,
                        "train_mae": train_mae,
                        "val_loss": val_loss,
                        "val_mae": val_mae,
                    }
                )

                # Log progress
                if epoch % 10 == 0 or epoch == num_epochs - 1:
                    logger.info(
                        f"Epoch {epoch:3d}/{num_epochs}: "
                        f"train_loss={train_loss:.4f}, train_mae={train_mae:.4f}, "
                        f"val_loss={val_loss:.4f}, val_mae={val_mae:.4f}"
                    )

                # Early stopping
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    patience_counter = 0
                    # Snapshot best weights. state_dict() returns references to
                    # the live parameter tensors, so clone them; otherwise later
                    # epochs silently overwrite the "best" snapshot.
                    self.best_model_state = {
                        k: v.detach().clone() for k, v in self.model.state_dict().items()
                    }
                else:
                    patience_counter += 1
                    if patience_counter >= early_stopping_patience:
                        logger.info(f"Early stopping at epoch {epoch}")
                        break

            # Restore best model
            self.model.load_state_dict(self.best_model_state)
            self.is_trained = True

            logger.info(f"Training complete. Best val loss: {best_val_loss:.4f}")

            return {
                "best_val_loss": best_val_loss,
                "num_epochs": epoch + 1,
                "history": self.training_history,
            }

        def _train_epoch(self, dataset: List[Data], batch_size: int) -> Tuple[float, float]:
            """Train one epoch."""
            self.model.train()

            # Shuffle and batch
            indices = torch.randperm(len(dataset))
            total_loss = 0.0
            total_mae = 0.0
            num_batches = 0

            for i in range(0, len(dataset), batch_size):
                batch_indices = indices[i : i + batch_size]
                batch_data = [dataset[idx] for idx in batch_indices]

                batch = Batch.from_data_list(batch_data).to(self.device)

                # Forward
                pred = self.model(batch.x, batch.edge_index, batch.batch)
                loss = F.mse_loss(pred, batch.y)
                mae = F.l1_loss(pred, batch.y)

                # Backward
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.optimizer.step()

                total_loss += loss.item()
                total_mae += mae.item()
                num_batches += 1

            return total_loss / num_batches, total_mae / num_batches

        def _validate(self, dataset: List[Data], batch_size: int) -> Tuple[float, float]:
            """Validate on dataset."""
            self.model.eval()

            total_loss = 0.0
            total_mae = 0.0
            num_batches = 0

            with torch.no_grad():
                for i in range(0, len(dataset), batch_size):
                    batch_data = dataset[i : i + batch_size]
                    batch = Batch.from_data_list(batch_data).to(self.device)

                    pred = self.model(batch.x, batch.edge_index, batch.batch)
                    loss = F.mse_loss(pred, batch.y)
                    mae = F.l1_loss(pred, batch.y)

                    total_loss += loss.item()
                    total_mae += mae.item()
                    num_batches += 1

            return total_loss / num_batches, total_mae / num_batches

        def predict(self, graph: ModelGraph) -> float:
            """
            Predict accuracy for architecture.

            Args:
                graph: ModelGraph to evaluate

            Returns:
                Predicted accuracy (0-1)
            """
            if not self.is_trained:
                logger.warning("GNN predictor not trained, prediction may be inaccurate")

            self.model.eval()

            data = self._graph_to_pyg_data(graph, 0.0).to(self.device)
            # Single-graph batch vector (every node belongs to graph 0)
            batch = torch.zeros(data.num_nodes, dtype=torch.long, device=self.device)

            with torch.no_grad():
                pred = self.model(data.x, data.edge_index, batch)

            return float(pred.item())

        def _graph_to_pyg_data(self, graph: ModelGraph, accuracy: float) -> Data:
            """
            Convert ModelGraph to PyTorch Geometric Data.

            Node features encode:
            - Operation type (one-hot)
            - Hyperparameters (normalized)
            - Positional encoding (layer depth)
            """
            # Extract nodes and edges
            node_list = list(graph.nodes.values())
            node_to_idx = {node.id: i for i, node in enumerate(node_list)}

            # Node features
            node_features = []
            for i, node in enumerate(node_list):
                feat = self._encode_node(node, i, len(node_list))
                node_features.append(feat)

            x = torch.tensor(node_features, dtype=torch.float)

            # Pad/truncate to fixed dimension
            feature_dim = self.config.get("node_feature_dim", 128)
            if x.shape[1] < feature_dim:
                padding = torch.zeros(x.shape[0], feature_dim - x.shape[1])
                x = torch.cat([x, padding], dim=1)
            elif x.shape[1] > feature_dim:
                x = x[:, :feature_dim]

            # Edge index
            edge_list = []
            for edge in graph.edges.values():
                source_idx = node_to_idx[edge.source_id]
                target_idx = node_to_idx[edge.target_id]
                edge_list.append([source_idx, target_idx])

            if edge_list:
                edge_index = torch.tensor(edge_list, dtype=torch.long).t()
            else:
                # Graph without edges - fall back to self-loops
                edge_index = torch.tensor(
                    [[i, i] for i in range(len(node_list))], dtype=torch.long
                ).t()

            # Label
            y = torch.tensor([accuracy], dtype=torch.float)

            # No `batch` attribute here: PyG reserves that name, and
            # Batch.from_data_list builds the batch vector itself
            return Data(x=x, edge_index=edge_index, y=y)

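        # Feature layout produced by _encode_node below: 20 one-hot operation
        # dims + 5 hyperparameter slots + 2 positional dims = 27 features,
        # zero-padded up to node_feature_dim (128 by default) in
        # _graph_to_pyg_data above.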
        def _encode_node(self, node, position: int, total_nodes: int) -> List[float]:
            """
            Encode node as a fixed-length feature vector.

            Features:
            - One-hot operation type (20 dims)
            - Hyperparameters (5 dims, zero when absent)
            - Positional encoding (2 dims)
            """
            features = []

            # Operation type (one-hot)
            operation_types = [
                "input",
                "output",
                "conv2d",
                "conv1d",
                "depthwise_conv",
                "maxpool",
                "avgpool",
                "globalavgpool",
                "dense",
                "linear",
                "relu",
                "gelu",
                "sigmoid",
                "tanh",
                "batchnorm",
                "layernorm",
                "dropout",
                "residual",
                "concat",
                "add",
            ]

            op_encoding = [0.0] * len(operation_types)
            if node.operation in operation_types:
                op_encoding[operation_types.index(node.operation)] = 1.0

            features.extend(op_encoding)

            # Hyperparameters (normalized). Every slot is always emitted so
            # that all nodes yield vectors of equal length; torch.tensor()
            # cannot stack ragged per-node feature lists.
            params = node.params if hasattr(node, "params") and node.params else {}
            features.append(min(params.get("filters", 0) / 512.0, 1.0))  # conv filters
            features.append(params.get("kernel_size", 0) / 7.0)  # conv kernel size
            features.append(params.get("stride", 0) / 2.0)  # conv stride
            features.append(min(params.get("units", 0) / 2048.0, 1.0))  # dense units
            features.append(params.get("rate", 0.0))  # dropout rate

            # Positional encoding (normalized depth)
            features.append(position / max(total_nodes, 1))
            features.append(np.sin(position * 2 * np.pi / max(total_nodes, 1)))

            return features

        def save(self, path: str) -> None:
            """Save model to file."""
            torch.save(
                {
                    "model_state": self.model.state_dict(),
                    "config": self.config,
                    "training_history": self.training_history,
                    "is_trained": self.is_trained,
                },
                path,
            )
            logger.info(f"GNN predictor saved to {path}")

        def load(self, path: str) -> None:
            """Load model from file."""
            checkpoint = torch.load(path, map_location=self.device)
            self.model.load_state_dict(checkpoint["model_state"])
            self.config = checkpoint["config"]
            self.training_history = checkpoint["training_history"]
            self.is_trained = checkpoint["is_trained"]
            logger.info(f"GNN predictor loaded from {path}")


# Fallback if PyTorch not available
else:

    class GNNPredictor:
        """Fallback GNN predictor (PyTorch not available)."""

        def __init__(self, config: Optional[Dict[str, Any]] = None):
            raise ImportError(
                "GNNPredictor requires PyTorch and PyTorch Geometric. "
                "Install with: pip install torch torch-geometric"
            )
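For reference, a minimal usage sketch of the class above. This is a sketch, not part of the package: `history` and `candidate_graph` are placeholders for (ModelGraph, accuracy) pairs and a ModelGraph produced elsewhere in morphml; the import path follows the wheel layout in the file list.

    from morphml.meta_learning.predictors.gnn_predictor import GNNPredictor

    # history: List[Tuple[ModelGraph, float]] gathered from finished runs
    predictor = GNNPredictor({"num_layers": 4, "hidden_dim": 256})
    stats = predictor.train(history, num_epochs=100, batch_size=32)
    print(f"Best validation loss: {stats['best_val_loss']:.4f}")

    # Cheap surrogate evaluation of a new candidate architecture
    score = predictor.predict(candidate_graph)
    predictor.save("gnn_predictor.pt")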