morphml 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of morphml might be problematic. Click here for more details.
- morphml/__init__.py +14 -0
- morphml/api/__init__.py +26 -0
- morphml/api/app.py +326 -0
- morphml/api/auth.py +193 -0
- morphml/api/client.py +338 -0
- morphml/api/models.py +132 -0
- morphml/api/rate_limit.py +192 -0
- morphml/benchmarking/__init__.py +36 -0
- morphml/benchmarking/comparison.py +430 -0
- morphml/benchmarks/__init__.py +56 -0
- morphml/benchmarks/comparator.py +409 -0
- morphml/benchmarks/datasets.py +280 -0
- morphml/benchmarks/metrics.py +199 -0
- morphml/benchmarks/openml_suite.py +201 -0
- morphml/benchmarks/problems.py +289 -0
- morphml/benchmarks/suite.py +318 -0
- morphml/cli/__init__.py +5 -0
- morphml/cli/commands/experiment.py +329 -0
- morphml/cli/main.py +457 -0
- morphml/cli/quickstart.py +312 -0
- morphml/config.py +278 -0
- morphml/constraints/__init__.py +19 -0
- morphml/constraints/handler.py +205 -0
- morphml/constraints/predicates.py +285 -0
- morphml/core/__init__.py +3 -0
- morphml/core/crossover.py +449 -0
- morphml/core/dsl/README.md +359 -0
- morphml/core/dsl/__init__.py +72 -0
- morphml/core/dsl/ast_nodes.py +364 -0
- morphml/core/dsl/compiler.py +318 -0
- morphml/core/dsl/layers.py +368 -0
- morphml/core/dsl/lexer.py +336 -0
- morphml/core/dsl/parser.py +455 -0
- morphml/core/dsl/search_space.py +386 -0
- morphml/core/dsl/syntax.py +199 -0
- morphml/core/dsl/type_system.py +361 -0
- morphml/core/dsl/validator.py +386 -0
- morphml/core/graph/__init__.py +40 -0
- morphml/core/graph/edge.py +124 -0
- morphml/core/graph/graph.py +507 -0
- morphml/core/graph/mutations.py +409 -0
- morphml/core/graph/node.py +196 -0
- morphml/core/graph/serialization.py +361 -0
- morphml/core/graph/visualization.py +431 -0
- morphml/core/objectives/__init__.py +20 -0
- morphml/core/search/__init__.py +33 -0
- morphml/core/search/individual.py +252 -0
- morphml/core/search/parameters.py +453 -0
- morphml/core/search/population.py +375 -0
- morphml/core/search/search_engine.py +340 -0
- morphml/distributed/__init__.py +76 -0
- morphml/distributed/fault_tolerance.py +497 -0
- morphml/distributed/health_monitor.py +348 -0
- morphml/distributed/master.py +709 -0
- morphml/distributed/proto/README.md +224 -0
- morphml/distributed/proto/__init__.py +74 -0
- morphml/distributed/proto/worker.proto +170 -0
- morphml/distributed/proto/worker_pb2.py +79 -0
- morphml/distributed/proto/worker_pb2_grpc.py +423 -0
- morphml/distributed/resource_manager.py +416 -0
- morphml/distributed/scheduler.py +567 -0
- morphml/distributed/storage/__init__.py +33 -0
- morphml/distributed/storage/artifacts.py +381 -0
- morphml/distributed/storage/cache.py +366 -0
- morphml/distributed/storage/checkpointing.py +329 -0
- morphml/distributed/storage/database.py +459 -0
- morphml/distributed/worker.py +549 -0
- morphml/evaluation/__init__.py +5 -0
- morphml/evaluation/heuristic.py +237 -0
- morphml/exceptions.py +55 -0
- morphml/execution/__init__.py +5 -0
- morphml/execution/local_executor.py +350 -0
- morphml/integrations/__init__.py +28 -0
- morphml/integrations/jax_adapter.py +206 -0
- morphml/integrations/pytorch_adapter.py +530 -0
- morphml/integrations/sklearn_adapter.py +206 -0
- morphml/integrations/tensorflow_adapter.py +230 -0
- morphml/logging_config.py +93 -0
- morphml/meta_learning/__init__.py +66 -0
- morphml/meta_learning/architecture_similarity.py +277 -0
- morphml/meta_learning/experiment_database.py +240 -0
- morphml/meta_learning/knowledge_base/__init__.py +19 -0
- morphml/meta_learning/knowledge_base/embedder.py +179 -0
- morphml/meta_learning/knowledge_base/knowledge_base.py +313 -0
- morphml/meta_learning/knowledge_base/meta_features.py +265 -0
- morphml/meta_learning/knowledge_base/vector_store.py +271 -0
- morphml/meta_learning/predictors/__init__.py +27 -0
- morphml/meta_learning/predictors/ensemble.py +221 -0
- morphml/meta_learning/predictors/gnn_predictor.py +552 -0
- morphml/meta_learning/predictors/learning_curve.py +231 -0
- morphml/meta_learning/predictors/proxy_metrics.py +261 -0
- morphml/meta_learning/strategy_evolution/__init__.py +27 -0
- morphml/meta_learning/strategy_evolution/adaptive_optimizer.py +226 -0
- morphml/meta_learning/strategy_evolution/bandit.py +276 -0
- morphml/meta_learning/strategy_evolution/portfolio.py +230 -0
- morphml/meta_learning/transfer.py +581 -0
- morphml/meta_learning/warm_start.py +286 -0
- morphml/optimizers/__init__.py +74 -0
- morphml/optimizers/adaptive_operators.py +399 -0
- morphml/optimizers/bayesian/__init__.py +52 -0
- morphml/optimizers/bayesian/acquisition.py +387 -0
- morphml/optimizers/bayesian/base.py +319 -0
- morphml/optimizers/bayesian/gaussian_process.py +635 -0
- morphml/optimizers/bayesian/smac.py +534 -0
- morphml/optimizers/bayesian/tpe.py +411 -0
- morphml/optimizers/differential_evolution.py +220 -0
- morphml/optimizers/evolutionary/__init__.py +61 -0
- morphml/optimizers/evolutionary/cma_es.py +416 -0
- morphml/optimizers/evolutionary/differential_evolution.py +556 -0
- morphml/optimizers/evolutionary/encoding.py +426 -0
- morphml/optimizers/evolutionary/particle_swarm.py +449 -0
- morphml/optimizers/genetic_algorithm.py +486 -0
- morphml/optimizers/gradient_based/__init__.py +22 -0
- morphml/optimizers/gradient_based/darts.py +550 -0
- morphml/optimizers/gradient_based/enas.py +585 -0
- morphml/optimizers/gradient_based/operations.py +474 -0
- morphml/optimizers/gradient_based/utils.py +601 -0
- morphml/optimizers/hill_climbing.py +169 -0
- morphml/optimizers/multi_objective/__init__.py +56 -0
- morphml/optimizers/multi_objective/indicators.py +504 -0
- morphml/optimizers/multi_objective/nsga2.py +647 -0
- morphml/optimizers/multi_objective/visualization.py +427 -0
- morphml/optimizers/nsga2.py +308 -0
- morphml/optimizers/random_search.py +172 -0
- morphml/optimizers/simulated_annealing.py +181 -0
- morphml/plugins/__init__.py +35 -0
- morphml/plugins/custom_evaluator_example.py +81 -0
- morphml/plugins/custom_optimizer_example.py +63 -0
- morphml/plugins/plugin_system.py +454 -0
- morphml/reports/__init__.py +30 -0
- morphml/reports/generator.py +362 -0
- morphml/tracking/__init__.py +7 -0
- morphml/tracking/experiment.py +309 -0
- morphml/tracking/logger.py +301 -0
- morphml/tracking/reporter.py +357 -0
- morphml/utils/__init__.py +6 -0
- morphml/utils/checkpoint.py +189 -0
- morphml/utils/comparison.py +390 -0
- morphml/utils/export.py +407 -0
- morphml/utils/progress.py +392 -0
- morphml/utils/validation.py +392 -0
- morphml/version.py +7 -0
- morphml/visualization/__init__.py +50 -0
- morphml/visualization/analytics.py +423 -0
- morphml/visualization/architecture_diagrams.py +353 -0
- morphml/visualization/architecture_plot.py +223 -0
- morphml/visualization/convergence_plot.py +174 -0
- morphml/visualization/crossover_viz.py +386 -0
- morphml/visualization/graph_viz.py +338 -0
- morphml/visualization/pareto_plot.py +149 -0
- morphml/visualization/plotly_dashboards.py +422 -0
- morphml/visualization/population.py +309 -0
- morphml/visualization/progress.py +260 -0
- morphml-1.0.0.dist-info/METADATA +434 -0
- morphml-1.0.0.dist-info/RECORD +158 -0
- morphml-1.0.0.dist-info/WHEEL +4 -0
- morphml-1.0.0.dist-info/entry_points.txt +3 -0
- morphml-1.0.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,530 @@
|
|
|
1
|
+
"""PyTorch adapter for MorphML.

Converts ModelGraph to PyTorch nn.Module with full training support.

PyTorch is an optional dependency: when ``import torch`` fails,
``TORCH_AVAILABLE`` is False and every torch-derived module-level name
(``torch``, ``nn``, ``optim``, ``DataLoader``) is bound to None.

Example:
    >>> from morphml.integrations import PyTorchAdapter
    >>> adapter = PyTorchAdapter()
    >>> model = adapter.build_model(graph)
    >>> trainer = adapter.get_trainer(model, config={'learning_rate': 1e-3})
    >>> results = trainer.train(train_loader, val_loader, num_epochs=50)
"""

from typing import Any, Dict, Optional, Tuple

import numpy as np

try:
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import DataLoader

    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    # Bind every name the try-branch would have bound, so later references
    # find None rather than an unbound global (NameError).
    torch = None
    nn = None
    optim = None
    DataLoader = None

from morphml.core.graph import GraphNode, ModelGraph
from morphml.logging_config import get_logger

logger = get_logger(__name__)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class PyTorchAdapter:
    """
    Convert ModelGraph to PyTorch nn.Module.

    Provides full integration with PyTorch including:
    - Model building from graph
    - Automatic shape inference
    - Training support
    - GPU acceleration

    Annotations that mention torch types are quoted so that defining this
    class never evaluates ``nn`` (which is None when PyTorch is missing);
    the missing dependency is reported by ``__init__`` instead.

    Example:
        >>> adapter = PyTorchAdapter()
        >>> model = adapter.build_model(graph)
        >>> model.train()
        >>> output = model(torch.randn(1, 3, 32, 32))
    """

    def __init__(self):
        """
        Initialize PyTorch adapter.

        Raises:
            ImportError: If PyTorch is not installed.
        """
        if not TORCH_AVAILABLE:
            raise ImportError(
                "PyTorch is required for PyTorchAdapter. " "Install with: pip install torch"
            )
        logger.info("Initialized PyTorchAdapter")

    def build_model(
        self, graph: "ModelGraph", input_shape: Optional[Tuple[int, ...]] = None
    ) -> "nn.Module":
        """
        Build PyTorch model from graph.

        Args:
            graph: ModelGraph to convert
            input_shape: Optional per-sample input shape (C, H, W)

        Returns:
            nn.Module instance (a GraphToModule)

        Example:
            >>> model = adapter.build_model(graph, input_shape=(3, 32, 32))
        """
        return GraphToModule(graph, input_shape)

    def get_trainer(
        self, model: "nn.Module", config: Optional[Dict[str, Any]] = None
    ) -> "PyTorchTrainer":
        """
        Get trainer for model.

        Args:
            model: PyTorch model
            config: Training configuration; None is treated as an empty dict

        Returns:
            PyTorchTrainer instance

        Example:
            >>> trainer = adapter.get_trainer(model, {
            ...     'learning_rate': 1e-3,
            ...     'weight_decay': 1e-4
            ... })
        """
        return PyTorchTrainer(model, config or {})
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class GraphToModule(nn.Module if nn is not None else object):
    """
    PyTorch module generated from ModelGraph.

    Dynamically creates layers based on graph structure and handles
    forward pass following graph topology. A node with several
    predecessors receives their outputs concatenated along dim 1; shape
    inference sums the predecessors' leading per-sample dimension to match,
    so layers placed after a merge are sized correctly.

    Attributes:
        graph: Source ModelGraph
        layers: ModuleDict of created layers, keyed by ``str(node_id)``
        input_shape: Expected per-sample input shape (default ``(3, 32, 32)``)
        shapes: Inferred per-sample output shape for each node id
    """

    def __init__(self, graph: "ModelGraph", input_shape: Optional[Tuple[int, ...]] = None):
        """
        Initialize module from graph.

        Args:
            graph: ModelGraph to convert
            input_shape: Optional per-sample input shape used for inference

        Raises:
            ImportError: If PyTorch is not installed.
        """
        # The base-class expression falls back to ``object`` only so this
        # class can be defined without torch; building an instance needs it.
        if nn is None:
            raise ImportError(
                "PyTorch is required for GraphToModule. Install with: pip install torch"
            )
        super().__init__()

        self.graph = graph
        self.input_shape = input_shape or (3, 32, 32)
        self.layers = nn.ModuleDict()

        # Layers are fixed at construction, so resolve the evaluation order and
        # the node whose output is returned once instead of on every forward.
        self._topo_order = list(graph.topological_sort())
        sinks = [n for n in graph.nodes.values() if not n.successors]
        if sinks:
            self._output_id = sinks[0].id
        elif self._topo_order:
            self._output_id = self._topo_order[-1].id
        else:
            self._output_id = None

        # Infer shapes
        self.shapes = self._infer_shapes()

        # Build layers
        for node_id, node in graph.nodes.items():
            layer = self._create_layer(node)
            if layer is not None:
                self.layers[str(node_id)] = layer

        logger.info(f"Created PyTorch model with {len(self.layers)} layers")

    def _infer_shapes(self) -> Dict[str, Tuple[int, ...]]:
        """Infer the per-sample output shape of every node, in topological order."""
        shapes: Dict[str, Tuple[int, ...]] = {}

        for node in self._topo_order:
            if node.operation == "input":
                shapes[node.id] = self.input_shape
            else:
                shapes[node.id] = self._infer_node_shape(node, shapes)

        return shapes

    def _input_shape_for(
        self, node: "GraphNode", shapes: Dict[str, Tuple[int, ...]]
    ) -> Tuple[int, ...]:
        """
        Return the per-sample shape ``forward`` feeds into a node that has predecessors.

        ``forward`` concatenates multiple predecessor outputs along dim 1, so
        their leading per-sample dimensions add up; the remaining dimensions
        are taken from the first predecessor.
        """
        pred_shapes = [tuple(shapes[pred.id]) for pred in node.predecessors]
        if len(pred_shapes) == 1:
            return pred_shapes[0]
        return (sum(shape[0] for shape in pred_shapes),) + pred_shapes[0][1:]

    def _infer_node_shape(
        self, node: "GraphNode", shapes: Dict[str, Tuple[int, ...]]
    ) -> Tuple[int, ...]:
        """Infer the per-sample output shape of a single node."""
        if not node.predecessors:
            return self.input_shape

        in_shape = self._input_shape_for(node, shapes)
        op = node.operation
        params = node.params

        if op == "conv2d":
            _, H, W = in_shape
            filters = params.get("filters", 64)
            kernel = params.get("kernel_size", 3)
            stride = params.get("stride", 1)
            padding = params.get("padding", 1)

            H_out = (H + 2 * padding - kernel) // stride + 1
            W_out = (W + 2 * padding - kernel) // stride + 1
            return (filters, H_out, W_out)

        elif op in ("maxpool", "avgpool"):
            C, H, W = in_shape
            pool_size = params.get("pool_size", 2)
            stride = params.get("stride")
            if stride is None:
                # nn.MaxPool2d / nn.AvgPool2d default the stride to kernel_size
                stride = pool_size
            # Same formula as the pooling layers built below (padding=0, floor mode)
            return (C, (H - pool_size) // stride + 1, (W - pool_size) // stride + 1)

        elif op == "flatten":
            return (int(np.prod(in_shape)),)

        elif op == "dense":
            return (params.get("units", 10),)

        else:
            # Activations, batchnorm, dropout, unknown ops: shape-preserving
            return in_shape

    def _create_layer(self, node: "GraphNode") -> Optional["nn.Module"]:
        """
        Create PyTorch layer from node.

        Args:
            node: GraphNode to convert

        Returns:
            nn.Module, or None for the "input" operation (no layer needed)
        """
        op = node.operation
        params = node.params

        if op == "input":
            return None  # No layer needed

        # Per-sample shape forward() feeds this node; None for source nodes,
        # which receive the raw model input.
        in_shape = self._input_shape_for(node, self.shapes) if node.predecessors else None

        if op == "conv2d":
            if in_shape is not None:
                in_channels = in_shape[0]
            else:
                in_channels = params.get("in_channels", 3)

            return nn.Conv2d(
                in_channels=in_channels,
                out_channels=params.get("filters", 64),
                kernel_size=params.get("kernel_size", 3),
                stride=params.get("stride", 1),
                padding=params.get("padding", 1),
            )

        elif op == "maxpool":
            return nn.MaxPool2d(
                kernel_size=params.get("pool_size", 2),
                stride=params.get("stride", params.get("pool_size", 2)),
            )

        elif op == "avgpool":
            return nn.AvgPool2d(
                kernel_size=params.get("pool_size", 2),
                stride=params.get("stride", params.get("pool_size", 2)),
            )

        elif op == "dense":
            if in_shape is not None:
                in_features = int(np.prod(in_shape))
            else:
                in_features = params.get("in_features", 512)

            return nn.Linear(in_features=in_features, out_features=params.get("units", 10))

        elif op == "relu":
            return nn.ReLU()

        elif op == "sigmoid":
            return nn.Sigmoid()

        elif op == "tanh":
            return nn.Tanh()

        elif op == "softmax":
            return nn.Softmax(dim=1)

        elif op == "batchnorm":
            if in_shape is None:
                return nn.Identity()
            if len(in_shape) == 3:  # (C, H, W)
                return nn.BatchNorm2d(in_shape[0])
            return nn.BatchNorm1d(in_shape[0])  # (features,)

        elif op == "dropout":
            return nn.Dropout(p=params.get("rate", 0.5))

        elif op == "flatten":
            return nn.Flatten()

        else:
            logger.warning(f"Unknown operation: {op}, using Identity")
            return nn.Identity()

    def forward(self, x: "torch.Tensor") -> "torch.Tensor":
        """
        Forward pass following graph topology.

        Args:
            x: Input tensor, fed to every node without predecessors

        Returns:
            Output of the first node (in ``graph.nodes`` order) that has no
            successors, or of the last node in topological order otherwise

        Raises:
            ValueError: If the graph has no nodes.
        """
        if self._output_id is None:
            raise ValueError("Cannot run forward on a graph with no nodes")

        outputs: Dict[Any, "torch.Tensor"] = {}

        for node in self._topo_order:
            if not node.predecessors:
                # Source node: consumes the model input directly
                node_input = x
            else:
                pred_outputs = [outputs[pred.id] for pred in node.predecessors]
                if len(pred_outputs) == 1:
                    node_input = pred_outputs[0]
                else:
                    # Concatenate along channel dimension
                    node_input = torch.cat(pred_outputs, dim=1)

            # nn.ModuleDict supports ``in`` and ``[]`` but has no dict-style
            # .get(); nodes without a layer (e.g. "input") pass through.
            key = str(node.id)
            if key in self.layers:
                outputs[node.id] = self.layers[key](node_input)
            else:
                outputs[node.id] = node_input

        return outputs[self._output_id]
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
class PyTorchTrainer:
    """
    Trainer for PyTorch models.

    Handles the training loop, validation and progress logging. Epoch
    losses are averaged per sample (not per batch), and accuracy compares
    ``argmax`` over dim 1 of the model output with the targets, i.e. it
    assumes class-index targets.

    Attributes:
        model: PyTorch model to train
        config: Training configuration
        device: Device (CPU/GPU)
        optimizer: Optimizer instance
        criterion: Loss function
        scheduler: ReduceLROnPlateau on validation accuracy, or None
    """

    def __init__(self, model: "nn.Module", config: Dict[str, Any]):
        """
        Initialize trainer.

        Args:
            model: PyTorch model (moved to ``self.device``)
            config: Training configuration with keys:
                - learning_rate: Learning rate (default: 1e-3)
                - weight_decay: Weight decay (default: 0)
                - optimizer: 'adam' or 'sgd' (default: 'adam'; unknown names
                  fall back to Adam with a warning)
                - momentum: SGD momentum (default: 0.9)
                - loss: 'cross_entropy', 'mse' or 'bce' (default:
                  'cross_entropy'; unknown names fall back to cross-entropy
                  with a warning)
                - use_scheduler: Halve the LR when validation accuracy
                  plateaus (default: False)
                - device: Explicit device string such as 'cpu' or 'cuda:0'
                  (default: CUDA if available, else CPU)
        """
        self.model = model
        self.config = config

        # Device: honour an explicit choice, otherwise prefer CUDA when present
        requested_device = config.get("device")
        if requested_device is not None:
            self.device = torch.device(requested_device)
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

        logger.info(f"Using device: {self.device}")

        # Optimizer
        optimizer_name = config.get("optimizer", "adam").lower()
        lr = config.get("learning_rate", 1e-3)
        weight_decay = config.get("weight_decay", 0)

        if optimizer_name == "sgd":
            self.optimizer = optim.SGD(
                self.model.parameters(),
                lr=lr,
                momentum=config.get("momentum", 0.9),
                weight_decay=weight_decay,
            )
        else:
            if optimizer_name != "adam":
                logger.warning(f"Unknown optimizer: {optimizer_name}, using Adam")
            # weight_decay is honoured on the fallback path as well
            self.optimizer = optim.Adam(
                self.model.parameters(), lr=lr, weight_decay=weight_decay
            )

        # Loss function
        loss_factories = {
            "cross_entropy": nn.CrossEntropyLoss,
            "mse": nn.MSELoss,
            "bce": nn.BCEWithLogitsLoss,
        }
        loss_name = config.get("loss", "cross_entropy").lower()
        if loss_name not in loss_factories:
            logger.warning(f"Unknown loss: {loss_name}, using cross_entropy")
        self.criterion = loss_factories.get(loss_name, nn.CrossEntropyLoss)()

        # Learning rate scheduler (stepped with validation accuracy, hence "max")
        if config.get("use_scheduler", False):
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, mode="max", factor=0.5, patience=5
            )
        else:
            self.scheduler = None

    def train(
        self,
        train_loader: "DataLoader",
        val_loader: Optional["DataLoader"] = None,
        num_epochs: int = 50,
    ) -> Dict[str, float]:
        """
        Train model.

        Args:
            train_loader: Training data loader
            val_loader: Validation data loader (optional)
            num_epochs: Number of epochs

        Returns:
            Training results dictionary with:
                - best_val_accuracy: Best validation accuracy (0.0 without val_loader)
                - final_train_accuracy: Final training accuracy
                - final_val_accuracy: Final validation accuracy (0.0 without val_loader)
        """
        best_val_acc = 0.0
        train_acc = 0.0
        val_acc = 0.0

        for epoch in range(num_epochs):
            # Train
            train_loss, train_acc = self._train_epoch(train_loader)

            # Validate
            if val_loader is not None:
                val_loss, val_acc = self._validate(val_loader)

                if val_acc > best_val_acc:
                    best_val_acc = val_acc

                # Learning rate scheduling
                if self.scheduler is not None:
                    self.scheduler.step(val_acc)

                if (epoch + 1) % 10 == 0:
                    logger.info(
                        f"Epoch {epoch+1}/{num_epochs}: "
                        f"train_loss={train_loss:.4f}, train_acc={train_acc:.4f}, "
                        f"val_loss={val_loss:.4f}, val_acc={val_acc:.4f}"
                    )
            else:
                if (epoch + 1) % 10 == 0:
                    logger.info(
                        f"Epoch {epoch+1}/{num_epochs}: "
                        f"train_loss={train_loss:.4f}, train_acc={train_acc:.4f}"
                    )

        return {
            "best_val_accuracy": best_val_acc,
            "final_train_accuracy": train_acc,
            "final_val_accuracy": val_acc,
        }

    @staticmethod
    def _epoch_metrics(total_loss: float, correct: int, total: int) -> Tuple[float, float]:
        """Turn accumulated sums into (mean loss, accuracy); (0.0, 0.0) for an empty loader."""
        if total == 0:
            return 0.0, 0.0
        return total_loss / total, correct / total

    def _train_epoch(self, loader: "DataLoader") -> Tuple[float, float]:
        """
        Single training epoch.

        Args:
            loader: Data loader yielding (inputs, targets) batches

        Returns:
            Tuple of (per-sample mean loss, accuracy)
        """
        self.model.train()

        total_loss = 0.0
        correct = 0
        total = 0

        for X, y in loader:
            X, y = X.to(self.device), y.to(self.device)

            # Forward
            logits = self.model(X)
            loss = self.criterion(logits, y)

            # Backward
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Metrics: weight the batch-mean loss by batch size so a short
            # final batch does not skew the epoch average
            batch_size = y.size(0)
            total_loss += loss.item() * batch_size
            correct += (logits.argmax(1) == y).sum().item()
            total += batch_size

        return self._epoch_metrics(total_loss, correct, total)

    def _validate(self, loader: "DataLoader") -> Tuple[float, float]:
        """
        Validation (no gradient tracking, model in eval mode).

        Args:
            loader: Data loader yielding (inputs, targets) batches

        Returns:
            Tuple of (per-sample mean loss, accuracy)
        """
        self.model.eval()

        total_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for X, y in loader:
                X, y = X.to(self.device), y.to(self.device)

                logits = self.model(X)
                loss = self.criterion(logits, y)

                batch_size = y.size(0)
                total_loss += loss.item() * batch_size
                correct += (logits.argmax(1) == y).sum().item()
                total += batch_size

        return self._epoch_metrics(total_loss, correct, total)

    def evaluate(self, test_loader: "DataLoader") -> Dict[str, float]:
        """
        Evaluate model on test set.

        Args:
            test_loader: Test data loader

        Returns:
            Evaluation metrics: ``test_loss`` and ``test_accuracy``
        """
        test_loss, test_acc = self._validate(test_loader)

        return {"test_loss": test_loss, "test_accuracy": test_acc}
|