morphml-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of morphml might be problematic.

Files changed (158)
  1. morphml/__init__.py +14 -0
  2. morphml/api/__init__.py +26 -0
  3. morphml/api/app.py +326 -0
  4. morphml/api/auth.py +193 -0
  5. morphml/api/client.py +338 -0
  6. morphml/api/models.py +132 -0
  7. morphml/api/rate_limit.py +192 -0
  8. morphml/benchmarking/__init__.py +36 -0
  9. morphml/benchmarking/comparison.py +430 -0
  10. morphml/benchmarks/__init__.py +56 -0
  11. morphml/benchmarks/comparator.py +409 -0
  12. morphml/benchmarks/datasets.py +280 -0
  13. morphml/benchmarks/metrics.py +199 -0
  14. morphml/benchmarks/openml_suite.py +201 -0
  15. morphml/benchmarks/problems.py +289 -0
  16. morphml/benchmarks/suite.py +318 -0
  17. morphml/cli/__init__.py +5 -0
  18. morphml/cli/commands/experiment.py +329 -0
  19. morphml/cli/main.py +457 -0
  20. morphml/cli/quickstart.py +312 -0
  21. morphml/config.py +278 -0
  22. morphml/constraints/__init__.py +19 -0
  23. morphml/constraints/handler.py +205 -0
  24. morphml/constraints/predicates.py +285 -0
  25. morphml/core/__init__.py +3 -0
  26. morphml/core/crossover.py +449 -0
  27. morphml/core/dsl/README.md +359 -0
  28. morphml/core/dsl/__init__.py +72 -0
  29. morphml/core/dsl/ast_nodes.py +364 -0
  30. morphml/core/dsl/compiler.py +318 -0
  31. morphml/core/dsl/layers.py +368 -0
  32. morphml/core/dsl/lexer.py +336 -0
  33. morphml/core/dsl/parser.py +455 -0
  34. morphml/core/dsl/search_space.py +386 -0
  35. morphml/core/dsl/syntax.py +199 -0
  36. morphml/core/dsl/type_system.py +361 -0
  37. morphml/core/dsl/validator.py +386 -0
  38. morphml/core/graph/__init__.py +40 -0
  39. morphml/core/graph/edge.py +124 -0
  40. morphml/core/graph/graph.py +507 -0
  41. morphml/core/graph/mutations.py +409 -0
  42. morphml/core/graph/node.py +196 -0
  43. morphml/core/graph/serialization.py +361 -0
  44. morphml/core/graph/visualization.py +431 -0
  45. morphml/core/objectives/__init__.py +20 -0
  46. morphml/core/search/__init__.py +33 -0
  47. morphml/core/search/individual.py +252 -0
  48. morphml/core/search/parameters.py +453 -0
  49. morphml/core/search/population.py +375 -0
  50. morphml/core/search/search_engine.py +340 -0
  51. morphml/distributed/__init__.py +76 -0
  52. morphml/distributed/fault_tolerance.py +497 -0
  53. morphml/distributed/health_monitor.py +348 -0
  54. morphml/distributed/master.py +709 -0
  55. morphml/distributed/proto/README.md +224 -0
  56. morphml/distributed/proto/__init__.py +74 -0
  57. morphml/distributed/proto/worker.proto +170 -0
  58. morphml/distributed/proto/worker_pb2.py +79 -0
  59. morphml/distributed/proto/worker_pb2_grpc.py +423 -0
  60. morphml/distributed/resource_manager.py +416 -0
  61. morphml/distributed/scheduler.py +567 -0
  62. morphml/distributed/storage/__init__.py +33 -0
  63. morphml/distributed/storage/artifacts.py +381 -0
  64. morphml/distributed/storage/cache.py +366 -0
  65. morphml/distributed/storage/checkpointing.py +329 -0
  66. morphml/distributed/storage/database.py +459 -0
  67. morphml/distributed/worker.py +549 -0
  68. morphml/evaluation/__init__.py +5 -0
  69. morphml/evaluation/heuristic.py +237 -0
  70. morphml/exceptions.py +55 -0
  71. morphml/execution/__init__.py +5 -0
  72. morphml/execution/local_executor.py +350 -0
  73. morphml/integrations/__init__.py +28 -0
  74. morphml/integrations/jax_adapter.py +206 -0
  75. morphml/integrations/pytorch_adapter.py +530 -0
  76. morphml/integrations/sklearn_adapter.py +206 -0
  77. morphml/integrations/tensorflow_adapter.py +230 -0
  78. morphml/logging_config.py +93 -0
  79. morphml/meta_learning/__init__.py +66 -0
  80. morphml/meta_learning/architecture_similarity.py +277 -0
  81. morphml/meta_learning/experiment_database.py +240 -0
  82. morphml/meta_learning/knowledge_base/__init__.py +19 -0
  83. morphml/meta_learning/knowledge_base/embedder.py +179 -0
  84. morphml/meta_learning/knowledge_base/knowledge_base.py +313 -0
  85. morphml/meta_learning/knowledge_base/meta_features.py +265 -0
  86. morphml/meta_learning/knowledge_base/vector_store.py +271 -0
  87. morphml/meta_learning/predictors/__init__.py +27 -0
  88. morphml/meta_learning/predictors/ensemble.py +221 -0
  89. morphml/meta_learning/predictors/gnn_predictor.py +552 -0
  90. morphml/meta_learning/predictors/learning_curve.py +231 -0
  91. morphml/meta_learning/predictors/proxy_metrics.py +261 -0
  92. morphml/meta_learning/strategy_evolution/__init__.py +27 -0
  93. morphml/meta_learning/strategy_evolution/adaptive_optimizer.py +226 -0
  94. morphml/meta_learning/strategy_evolution/bandit.py +276 -0
  95. morphml/meta_learning/strategy_evolution/portfolio.py +230 -0
  96. morphml/meta_learning/transfer.py +581 -0
  97. morphml/meta_learning/warm_start.py +286 -0
  98. morphml/optimizers/__init__.py +74 -0
  99. morphml/optimizers/adaptive_operators.py +399 -0
  100. morphml/optimizers/bayesian/__init__.py +52 -0
  101. morphml/optimizers/bayesian/acquisition.py +387 -0
  102. morphml/optimizers/bayesian/base.py +319 -0
  103. morphml/optimizers/bayesian/gaussian_process.py +635 -0
  104. morphml/optimizers/bayesian/smac.py +534 -0
  105. morphml/optimizers/bayesian/tpe.py +411 -0
  106. morphml/optimizers/differential_evolution.py +220 -0
  107. morphml/optimizers/evolutionary/__init__.py +61 -0
  108. morphml/optimizers/evolutionary/cma_es.py +416 -0
  109. morphml/optimizers/evolutionary/differential_evolution.py +556 -0
  110. morphml/optimizers/evolutionary/encoding.py +426 -0
  111. morphml/optimizers/evolutionary/particle_swarm.py +449 -0
  112. morphml/optimizers/genetic_algorithm.py +486 -0
  113. morphml/optimizers/gradient_based/__init__.py +22 -0
  114. morphml/optimizers/gradient_based/darts.py +550 -0
  115. morphml/optimizers/gradient_based/enas.py +585 -0
  116. morphml/optimizers/gradient_based/operations.py +474 -0
  117. morphml/optimizers/gradient_based/utils.py +601 -0
  118. morphml/optimizers/hill_climbing.py +169 -0
  119. morphml/optimizers/multi_objective/__init__.py +56 -0
  120. morphml/optimizers/multi_objective/indicators.py +504 -0
  121. morphml/optimizers/multi_objective/nsga2.py +647 -0
  122. morphml/optimizers/multi_objective/visualization.py +427 -0
  123. morphml/optimizers/nsga2.py +308 -0
  124. morphml/optimizers/random_search.py +172 -0
  125. morphml/optimizers/simulated_annealing.py +181 -0
  126. morphml/plugins/__init__.py +35 -0
  127. morphml/plugins/custom_evaluator_example.py +81 -0
  128. morphml/plugins/custom_optimizer_example.py +63 -0
  129. morphml/plugins/plugin_system.py +454 -0
  130. morphml/reports/__init__.py +30 -0
  131. morphml/reports/generator.py +362 -0
  132. morphml/tracking/__init__.py +7 -0
  133. morphml/tracking/experiment.py +309 -0
  134. morphml/tracking/logger.py +301 -0
  135. morphml/tracking/reporter.py +357 -0
  136. morphml/utils/__init__.py +6 -0
  137. morphml/utils/checkpoint.py +189 -0
  138. morphml/utils/comparison.py +390 -0
  139. morphml/utils/export.py +407 -0
  140. morphml/utils/progress.py +392 -0
  141. morphml/utils/validation.py +392 -0
  142. morphml/version.py +7 -0
  143. morphml/visualization/__init__.py +50 -0
  144. morphml/visualization/analytics.py +423 -0
  145. morphml/visualization/architecture_diagrams.py +353 -0
  146. morphml/visualization/architecture_plot.py +223 -0
  147. morphml/visualization/convergence_plot.py +174 -0
  148. morphml/visualization/crossover_viz.py +386 -0
  149. morphml/visualization/graph_viz.py +338 -0
  150. morphml/visualization/pareto_plot.py +149 -0
  151. morphml/visualization/plotly_dashboards.py +422 -0
  152. morphml/visualization/population.py +309 -0
  153. morphml/visualization/progress.py +260 -0
  154. morphml-1.0.0.dist-info/METADATA +434 -0
  155. morphml-1.0.0.dist-info/RECORD +158 -0
  156. morphml-1.0.0.dist-info/WHEEL +4 -0
  157. morphml-1.0.0.dist-info/entry_points.txt +3 -0
  158. morphml-1.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,530 @@
+ """PyTorch adapter for MorphML.
+
+ Converts ModelGraph to PyTorch nn.Module with full training support.
+
+ Example:
+     >>> from morphml.integrations import PyTorchAdapter
+     >>> adapter = PyTorchAdapter()
+     >>> model = adapter.build_model(graph)
+     >>> trainer = adapter.get_trainer(model, config={'learning_rate': 1e-3})
+     >>> results = trainer.train(train_loader, val_loader, num_epochs=50)
+ """
+
+ from typing import Any, Dict, Optional, Tuple
+
+ import numpy as np
+
+ try:
+     import torch
+     import torch.nn as nn
+     import torch.optim as optim
+     from torch.utils.data import DataLoader
+
+     TORCH_AVAILABLE = True
+ except ImportError:
+     TORCH_AVAILABLE = False
+     torch = None
+     nn = None
+     optim = None
+     DataLoader = None
+
+ from morphml.core.graph import GraphNode, ModelGraph
+ from morphml.logging_config import get_logger
+
+ logger = get_logger(__name__)
+
+
+ class PyTorchAdapter:
+     """
+     Convert ModelGraph to PyTorch nn.Module.
+
+     Provides full integration with PyTorch including:
+     - Model building from graph
+     - Automatic shape inference
+     - Training support
+     - GPU acceleration
+
+     Example:
+         >>> adapter = PyTorchAdapter()
+         >>> model = adapter.build_model(graph)
+         >>> model.train()
+         >>> output = model(torch.randn(1, 3, 32, 32))
+     """
+
+     def __init__(self):
+         """Initialize PyTorch adapter."""
+         if not TORCH_AVAILABLE:
+             raise ImportError(
+                 "PyTorch is required for PyTorchAdapter. Install with: pip install torch"
+             )
+         logger.info("Initialized PyTorchAdapter")
+
+     def build_model(
+         self, graph: ModelGraph, input_shape: Optional[Tuple[int, ...]] = None
+     ) -> nn.Module:
+         """
+         Build PyTorch model from graph.
+
+         Args:
+             graph: ModelGraph to convert
+             input_shape: Optional input shape (C, H, W)
+
+         Returns:
+             nn.Module instance
+
+         Example:
+             >>> model = adapter.build_model(graph, input_shape=(3, 32, 32))
+         """
+         return GraphToModule(graph, input_shape)
+
+     def get_trainer(
+         self, model: nn.Module, config: Optional[Dict[str, Any]] = None
+     ) -> "PyTorchTrainer":
+         """
+         Get trainer for model.
+
+         Args:
+             model: PyTorch model
+             config: Training configuration
+
+         Returns:
+             PyTorchTrainer instance
+
+         Example:
+             >>> trainer = adapter.get_trainer(model, {
+             ...     'learning_rate': 1e-3,
+             ...     'weight_decay': 1e-4
+             ... })
+         """
+         return PyTorchTrainer(model, config or {})
+
+
+ class GraphToModule(nn.Module):
+     """
+     PyTorch module generated from ModelGraph.
+
+     Dynamically creates layers based on graph structure and handles
+     forward pass following graph topology.
+
+     Attributes:
+         graph: Source ModelGraph
+         layers: ModuleDict of created layers
+         input_shape: Expected input shape
+     """
+
+     def __init__(self, graph: ModelGraph, input_shape: Optional[Tuple[int, ...]] = None):
+         """
+         Initialize module from graph.
+
+         Args:
+             graph: ModelGraph to convert
+             input_shape: Optional input shape for inference
+         """
+         super().__init__()
+
+         self.graph = graph
+         self.input_shape = input_shape or (3, 32, 32)
+         self.layers = nn.ModuleDict()
+
+         # Infer shapes
+         self.shapes = self._infer_shapes()
+
+         # Build layers
+         for node_id, node in graph.nodes.items():
+             layer = self._create_layer(node)
+             if layer is not None:
+                 self.layers[str(node_id)] = layer
+
+         logger.info(f"Created PyTorch model with {len(self.layers)} layers")
+
+     def _infer_shapes(self) -> Dict[str, Tuple[int, ...]]:
+         """Infer shapes for all nodes."""
+         shapes = {}
+
+         for node in self.graph.topological_sort():
+             if node.operation == "input":
+                 shapes[node.id] = self.input_shape
+             else:
+                 shapes[node.id] = self._infer_node_shape(node, shapes)
+
+         return shapes
+
+     def _infer_node_shape(
+         self, node: GraphNode, shapes: Dict[str, Tuple[int, ...]]
+     ) -> Tuple[int, ...]:
+         """Infer shape for a single node."""
+         if not node.predecessors:
+             return self.input_shape
+
+         pred_shape = shapes[list(node.predecessors)[0].id]
+         op = node.operation
+         params = node.params
+
+         if op == "conv2d":
+             C, H, W = pred_shape
+             filters = params.get("filters", 64)
+             kernel = params.get("kernel_size", 3)
+             stride = params.get("stride", 1)
+             padding = params.get("padding", 1)
+
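+             # Standard convolution arithmetic: out = floor((in + 2*padding - kernel) / stride) + 1.
+             # E.g. a 32x32 input with kernel=3, stride=1, padding=1 stays 32x32.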
+             H_out = (H + 2 * padding - kernel) // stride + 1
+             W_out = (W + 2 * padding - kernel) // stride + 1
+             return (filters, H_out, W_out)
+
+         elif op in ["maxpool", "avgpool"]:
+             C, H, W = pred_shape
+             pool_size = params.get("pool_size", 2)
+             stride = params.get("stride", pool_size)
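+             # H // stride matches floor((H - kernel) / stride) + 1 exactly when
+             # kernel_size == stride with no padding (the default here); otherwise approximate.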
+             return (C, H // stride, W // stride)
+
+         elif op == "flatten":
+             return (int(np.prod(pred_shape)),)
+
+         elif op == "dense":
+             return (params.get("units", 10),)
+
+         else:
+             return pred_shape
+
+     def _create_layer(self, node: GraphNode) -> Optional[nn.Module]:
+         """
+         Create PyTorch layer from node.
+
+         Args:
+             node: GraphNode to convert
+
+         Returns:
+             nn.Module or None for functional operations
+         """
+         op = node.operation
+         params = node.params
+
+         if op == "input":
+             return None  # No layer needed
+
+         elif op == "conv2d":
+             # Get input channels from predecessor
+             if node.predecessors:
+                 pred_shape = self.shapes[list(node.predecessors)[0].id]
+                 in_channels = pred_shape[0]
+             else:
+                 in_channels = params.get("in_channels", 3)
+
+             return nn.Conv2d(
+                 in_channels=in_channels,
+                 out_channels=params.get("filters", 64),
+                 kernel_size=params.get("kernel_size", 3),
+                 stride=params.get("stride", 1),
+                 padding=params.get("padding", 1),
+             )
+
+         elif op == "maxpool":
+             return nn.MaxPool2d(
+                 kernel_size=params.get("pool_size", 2),
+                 stride=params.get("stride", params.get("pool_size", 2)),
+             )
+
+         elif op == "avgpool":
+             return nn.AvgPool2d(
+                 kernel_size=params.get("pool_size", 2),
+                 stride=params.get("stride", params.get("pool_size", 2)),
+             )
+
+         elif op == "dense":
+             # Get input features from predecessor
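+             # (nn.Linear acts on the last dimension of its input, so a flatten node is
+             # expected upstream when the predecessor produces a (C, H, W) feature map)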
+             if node.predecessors:
+                 pred_shape = self.shapes[list(node.predecessors)[0].id]
+                 in_features = int(np.prod(pred_shape))
+             else:
+                 in_features = params.get("in_features", 512)
+
+             return nn.Linear(in_features=in_features, out_features=params.get("units", 10))
+
+         elif op == "relu":
+             return nn.ReLU()
+
+         elif op == "sigmoid":
+             return nn.Sigmoid()
+
+         elif op == "tanh":
+             return nn.Tanh()
+
+         elif op == "softmax":
+             return nn.Softmax(dim=1)
+
+         elif op == "batchnorm":
+             # Infer num_features from predecessor
+             if node.predecessors:
+                 pred_shape = self.shapes[list(node.predecessors)[0].id]
+                 if len(pred_shape) == 3:  # (C, H, W)
+                     return nn.BatchNorm2d(pred_shape[0])
+                 else:  # (features,)
+                     return nn.BatchNorm1d(pred_shape[0])
+             return nn.Identity()
+
+         elif op == "dropout":
+             return nn.Dropout(p=params.get("rate", 0.5))
+
+         elif op == "flatten":
+             return nn.Flatten()
+
+         else:
+             logger.warning(f"Unknown operation: {op}, using Identity")
+             return nn.Identity()
+
+     def forward(self, x: "torch.Tensor") -> "torch.Tensor":
+         """
+         Forward pass following graph topology.
+
+         Args:
+             x: Input tensor
+
+         Returns:
+             Output tensor
+         """
+         # Topological sort
+         topo_order = self.graph.topological_sort()
+
+         # Track outputs
+         outputs = {}
+
+         for node in topo_order:
+             # Get layer
+             layer = self.layers.get(str(node.id))
+
+             # Get input
+             if not node.predecessors:
+                 # Input node
+                 node_input = x
+             else:
+                 # Combine predecessor outputs
+                 pred_outputs = [outputs[pred.id] for pred in node.predecessors]
+
+                 if len(pred_outputs) == 1:
+                     node_input = pred_outputs[0]
+                 else:
+                     # Concatenate along channel dimension
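+                     # (torch.cat on dim=1 requires every branch output to match in
+                     # batch and spatial dimensions; only channel counts may differ)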
+                     node_input = torch.cat(pred_outputs, dim=1)
+
+             # Apply layer
+             if layer is not None:
+                 outputs[node.id] = layer(node_input)
+             else:
+                 outputs[node.id] = node_input
+
+         # Return output node's output
+         output_nodes = [n for n in self.graph.nodes.values() if not n.successors]
+         if output_nodes:
+             return outputs[output_nodes[0].id]
+         else:
+             # Return last node's output
+             return outputs[topo_order[-1].id]
+
+
+ class PyTorchTrainer:
+     """
+     Trainer for PyTorch models.
+
+     Handles the training loop, validation, and logging.
+
+     Attributes:
+         model: PyTorch model to train
+         config: Training configuration
+         device: Device (CPU/GPU)
+         optimizer: Optimizer instance
+         criterion: Loss function
+     """
+
+     def __init__(self, model: nn.Module, config: Dict[str, Any]):
+         """
+         Initialize trainer.
+
+         Args:
+             model: PyTorch model
+             config: Training configuration with keys:
+                 - learning_rate: Learning rate (default: 1e-3)
+                 - weight_decay: Weight decay (default: 0)
+                 - optimizer: Optimizer name (default: 'adam')
+                 - loss: Loss function name (default: 'cross_entropy')
+         """
+         self.model = model
+         self.config = config
+
+         # Device
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.model = self.model.to(self.device)
+
+         logger.info(f"Using device: {self.device}")
+
+         # Optimizer
+         optimizer_name = config.get("optimizer", "adam").lower()
+         lr = config.get("learning_rate", 1e-3)
+         weight_decay = config.get("weight_decay", 0)
+
+         if optimizer_name == "adam":
+             self.optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
+         elif optimizer_name == "sgd":
+             self.optimizer = optim.SGD(
+                 model.parameters(),
+                 lr=lr,
+                 momentum=config.get("momentum", 0.9),
+                 weight_decay=weight_decay,
+             )
+         else:
+             self.optimizer = optim.Adam(model.parameters(), lr=lr)
+
+         # Loss function
+         loss_name = config.get("loss", "cross_entropy").lower()
+         if loss_name == "cross_entropy":
+             self.criterion = nn.CrossEntropyLoss()
+         elif loss_name == "mse":
+             self.criterion = nn.MSELoss()
+         elif loss_name == "bce":
+             self.criterion = nn.BCEWithLogitsLoss()
+         else:
+             self.criterion = nn.CrossEntropyLoss()
+
+         # Learning rate scheduler
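+         # (ReduceLROnPlateau with mode="max": train() passes validation accuracy to
+         # scheduler.step(), so the LR is halved after `patience` epochs without improvement)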
+         if config.get("use_scheduler", False):
+             self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
+                 self.optimizer, mode="max", factor=0.5, patience=5
+             )
+         else:
+             self.scheduler = None
+
+     def train(
+         self,
+         train_loader: DataLoader,
+         val_loader: Optional[DataLoader] = None,
+         num_epochs: int = 50,
+     ) -> Dict[str, float]:
+         """
+         Train model.
+
+         Args:
+             train_loader: Training data loader
+             val_loader: Validation data loader (optional)
+             num_epochs: Number of epochs
+
+         Returns:
+             Training results dictionary with:
+                 - best_val_accuracy: Best validation accuracy
+                 - final_train_accuracy: Final training accuracy
+                 - final_val_accuracy: Final validation accuracy
+         """
+         best_val_acc = 0.0
+         train_acc = 0.0
+         val_acc = 0.0
+
+         for epoch in range(num_epochs):
+             # Train
+             train_loss, train_acc = self._train_epoch(train_loader)
+
+             # Validate
+             if val_loader is not None:
+                 val_loss, val_acc = self._validate(val_loader)
+
+                 if val_acc > best_val_acc:
+                     best_val_acc = val_acc
+
+                 # Learning rate scheduling
+                 if self.scheduler is not None:
+                     self.scheduler.step(val_acc)
+
+                 if (epoch + 1) % 10 == 0:
+                     logger.info(
+                         f"Epoch {epoch+1}/{num_epochs}: "
+                         f"train_loss={train_loss:.4f}, train_acc={train_acc:.4f}, "
+                         f"val_loss={val_loss:.4f}, val_acc={val_acc:.4f}"
+                     )
+             else:
+                 if (epoch + 1) % 10 == 0:
+                     logger.info(
+                         f"Epoch {epoch+1}/{num_epochs}: "
+                         f"train_loss={train_loss:.4f}, train_acc={train_acc:.4f}"
+                     )
+
+         return {
+             "best_val_accuracy": best_val_acc,
+             "final_train_accuracy": train_acc,
+             "final_val_accuracy": val_acc,
+         }
+
+     def _train_epoch(self, loader: DataLoader) -> Tuple[float, float]:
+         """
+         Single training epoch.
+
+         Args:
+             loader: Data loader
+
+         Returns:
+             Tuple of (loss, accuracy)
+         """
+         self.model.train()
+
+         total_loss = 0.0
+         correct = 0
+         total = 0
+
+         for X, y in loader:
+             X, y = X.to(self.device), y.to(self.device)
+
+             # Forward
+             logits = self.model(X)
+             loss = self.criterion(logits, y)
+
+             # Backward
+             self.optimizer.zero_grad()
+             loss.backward()
+             self.optimizer.step()
+
+             # Metrics
+             total_loss += loss.item()
+             pred = logits.argmax(1)
+             correct += (pred == y).sum().item()
+             total += y.size(0)
+
+         return total_loss / len(loader), correct / total
+
+     def _validate(self, loader: DataLoader) -> Tuple[float, float]:
+         """
+         Validation.
+
+         Args:
+             loader: Data loader
+
+         Returns:
+             Tuple of (loss, accuracy)
+         """
+         self.model.eval()
+
+         total_loss = 0.0
+         correct = 0
+         total = 0
+
+         with torch.no_grad():
+             for X, y in loader:
+                 X, y = X.to(self.device), y.to(self.device)
+
+                 logits = self.model(X)
+                 loss = self.criterion(logits, y)
+
+                 total_loss += loss.item()
+                 pred = logits.argmax(1)
+                 correct += (pred == y).sum().item()
+                 total += y.size(0)
+
+         return total_loss / len(loader), correct / total
+
+     def evaluate(self, test_loader: DataLoader) -> Dict[str, float]:
+         """
+         Evaluate model on test set.
+
+         Args:
+             test_loader: Test data loader
+
+         Returns:
+             Evaluation metrics
+         """
+         test_loss, test_acc = self._validate(test_loader)
+
+         return {"test_loss": test_loss, "test_accuracy": test_acc}