cifar10-tools 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cifar10_tools/pytorch/__init__.py +33 -0
- cifar10_tools/pytorch/data.py +111 -3
- cifar10_tools/pytorch/hyperparameter_optimization.py +278 -0
- cifar10_tools/pytorch/plotting.py +77 -8
- cifar10_tools/pytorch/training.py +29 -4
- cifar10_tools-0.4.0.dist-info/METADATA +72 -0
- cifar10_tools-0.4.0.dist-info/RECORD +12 -0
- cifar10_tools-0.2.0.dist-info/METADATA +0 -35
- cifar10_tools-0.2.0.dist-info/RECORD +0 -11
- {cifar10_tools-0.2.0.dist-info → cifar10_tools-0.4.0.dist-info}/WHEEL +0 -0
- {cifar10_tools-0.2.0.dist-info → cifar10_tools-0.4.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
'''PyTorch utilities for CIFAR-10 classification.'''

from cifar10_tools.pytorch.data import download_cifar10_data
from cifar10_tools.pytorch.evaluation import evaluate_model
from cifar10_tools.pytorch.training import train_model
from cifar10_tools.pytorch.plotting import (
    plot_sample_images,
    plot_learning_curves,
    plot_confusion_matrix,
    plot_class_probability_distributions,
    plot_evaluation_curves,
    plot_optimization_results
)

# Names provided by cifar10_tools.pytorch.hyperparameter_optimization. That
# module imports Optuna at import time, but Optuna is not one of the package's
# declared install requirements, so these names are resolved lazily (PEP 562)
# instead of being imported eagerly. The rest of the package stays importable
# without Optuna; it is only needed once one of these names is accessed.
_LAZY_HPO_EXPORTS = frozenset({'create_cnn', 'train_trial', 'create_objective'})


def __getattr__(name: str):
    '''Resolve the Optuna-dependent exports on first access.'''

    if name in _LAZY_HPO_EXPORTS:
        import importlib

        module = importlib.import_module(f'{__name__}.hyperparameter_optimization')
        return getattr(module, name)

    raise AttributeError(f'module {__name__!r} has no attribute {name!r}')


def __dir__():
    '''Include the lazily resolved exports in dir() output.'''

    return sorted(set(globals()) | _LAZY_HPO_EXPORTS)


__all__ = [
    'download_cifar10_data',
    'evaluate_model',
    'train_model',
    'plot_sample_images',
    'plot_learning_curves',
    'plot_confusion_matrix',
    'plot_class_probability_distributions',
    'plot_evaluation_curves',
    'plot_optimization_results',
    'create_cnn',
    'train_trial',
    'create_objective'
]
|
cifar10_tools/pytorch/data.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
|
1
|
-
'''Data
|
|
2
|
-
during devcontainer creation'''
|
|
1
|
+
'''Data loading and preprocessing functions for CIFAR-10 dataset.'''
|
|
3
2
|
|
|
4
3
|
from pathlib import Path
|
|
5
|
-
|
|
4
|
+
import torch
|
|
5
|
+
from torchvision import datasets, transforms
|
|
6
|
+
from torch.utils.data import DataLoader
|
|
7
|
+
|
|
6
8
|
|
|
7
9
|
def download_cifar10_data(data_dir: str='data/pytorch/cifar10'):
|
|
8
10
|
'''Download CIFAR-10 dataset using torchvision.datasets.'''
|
|
@@ -22,6 +24,112 @@ def download_cifar10_data(data_dir: str='data/pytorch/cifar10'):
|
|
|
22
24
|
download=True
|
|
23
25
|
)
|
|
24
26
|
|
|
27
|
+
|
|
28
|
+
def _preload_to_device(dataset, indices, device):
    '''Materialize ``dataset[i]`` for each index and stack the results on ``device``.

    Each sample is fetched, and therefore transformed, exactly once. Any random
    augmentation in the dataset's transform is applied once here and is then
    fixed for every epoch.

    Args:
        dataset: Indexable dataset returning (image_tensor, int_label) pairs.
        indices: Iterable of integer indices to materialize, in output order.
        device: Device to move the stacked tensors onto.

    Returns:
        TensorDataset of (images, labels) living on ``device``.
    '''
    images = []
    labels = []

    for idx in indices:
        img, label = dataset[idx]
        images.append(img)
        labels.append(label)

    return torch.utils.data.TensorDataset(
        torch.stack(images).to(device),
        torch.tensor(labels).to(device),
    )


def make_data_loaders(
    data_dir: Path,
    batch_size: int,
    train_transform: transforms.Compose,
    eval_transform: transforms.Compose,
    device: torch.device | None = None,
    download: bool = False,
    val_fraction: float = 0.2,
    seed: int | None = None,
):
    """
    Loads CIFAR-10, applies preprocessing with separate train/eval transforms,
    and returns DataLoaders.

    The CIFAR-10 training split is randomly partitioned into train and
    validation subsets. Both subsets index the same underlying images, but the
    train subset is read through ``train_transform`` and the validation subset
    through ``eval_transform``. The CIFAR-10 test split uses ``eval_transform``.

    Args:
        data_dir: Path to CIFAR-10 data directory
        batch_size: Batch size for DataLoaders
        train_transform: Transform to apply to training data
        eval_transform: Transform to apply to validation and test data
        device: Device to preload tensors onto. If None, data stays on CPU
            and transforms are applied on-the-fly during iteration. When a
            device is given, every transform is applied exactly once while
            preloading, so random augmentations in ``train_transform`` are
            fixed rather than re-sampled each epoch.
        download: Whether to download the dataset if not present
        val_fraction: Fraction of the training split held out for validation;
            must be strictly between 0 and 1 (default 0.2, i.e. an 80/20 split).
        seed: Optional seed for the train/validation split. If None, the global
            torch RNG is used, so the split differs between calls unless
            ``torch.manual_seed`` was set beforehand.

    Returns:
        Tuple of (train_loader, val_loader, test_loader)

    Raises:
        ValueError: If ``val_fraction`` is not strictly between 0 and 1.
    """

    if not 0.0 < val_fraction < 1.0:
        raise ValueError(f'val_fraction must be strictly between 0 and 1, got {val_fraction}')

    # Load datasets with respective transforms
    train_dataset_full = datasets.CIFAR10(
        root=data_dir,
        train=True,
        download=download,
        transform=train_transform,
    )

    # Same training images, read through the eval transform (used for validation)
    val_test_dataset_full = datasets.CIFAR10(
        root=data_dir,
        train=True,
        download=download,
        transform=eval_transform,
    )

    test_dataset = datasets.CIFAR10(
        root=data_dir,
        train=False,
        download=download,
        transform=eval_transform,
    )

    # Random train/val split of the training images. Indices are drawn before
    # any preloading so that only the samples each split needs are transformed.
    n_samples = len(train_dataset_full)
    n_train = int((1.0 - val_fraction) * n_samples)
    generator = None if seed is None else torch.Generator().manual_seed(seed)
    indices = torch.randperm(n_samples, generator=generator).tolist()

    train_indices = indices[:n_train]
    val_indices = indices[n_train:]

    if device is not None:
        # Preload to device for faster training: one pass per split, and only
        # the validation slice of the eval-transformed training images.
        train_split = _preload_to_device(train_dataset_full, train_indices, device)
        val_split = _preload_to_device(val_test_dataset_full, val_indices, device)
        test_split = _preload_to_device(test_dataset, range(len(test_dataset)), device)

    else:
        # Don't preload - use datasets directly for on-the-fly transforms
        train_split = torch.utils.data.Subset(train_dataset_full, train_indices)
        val_split = torch.utils.data.Subset(val_test_dataset_full, val_indices)
        test_split = test_dataset

    # DataLoaders
    train_loader = DataLoader(train_split, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_split, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_split, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader
|
|
131
|
+
|
|
132
|
+
|
|
25
133
|
if __name__ == '__main__':
    # Running data.py as a script downloads CIFAR-10 into the default
    # data_dir of download_cifar10_data.
    download_cifar10_data()
|
|
@@ -0,0 +1,278 @@
|
|
|
1
|
+
'''Hyperparameter optimization utilities for CNN models using Optuna.
|
|
2
|
+
|
|
3
|
+
This module provides functions for building configurable CNN architectures
|
|
4
|
+
and running hyperparameter optimization with Optuna.
|
|
5
|
+
'''
|
|
6
|
+
|
|
7
|
+
from typing import Callable
|
|
8
|
+
|
|
9
|
+
import optuna
|
|
10
|
+
import torch
|
|
11
|
+
import torch.nn as nn
|
|
12
|
+
import torch.optim as optim
|
|
13
|
+
from torch.utils.data import DataLoader
|
|
14
|
+
|
|
15
|
+
from cifar10_tools.pytorch.data import make_data_loaders
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def create_cnn(
    n_conv_blocks: int,
    initial_filters: int,
    n_fc_layers: int,
    base_kernel_size: int,
    conv_dropout_rate: float,
    fc_dropout_rate: float,
    pooling_strategy: str,
    use_batch_norm: bool,
    num_classes: int = 10,
    in_channels: int = 3,
    input_size: int = 32
) -> nn.Sequential:
    '''Create a CNN with configurable architecture.

    Each convolutional block is conv -> [batch norm] -> ReLU -> conv ->
    [batch norm] -> ReLU -> 2x2 pooling -> dropout, and the filter count
    doubles from block to block. The classifier flattens the feature maps and
    applies ``n_fc_layers`` Linear -> ReLU -> Dropout layers. Their widths
    start at half the flattened size and halve each layer, with a minimum of
    10 units. A final Linear layer produces ``num_classes`` logits.

    Args:
        n_conv_blocks: Number of convolutional blocks (the Optuna search
            space uses 1-5). 0 builds a classifier on the flattened input.
        initial_filters: Number of filters in first conv layer (doubles each block)
        n_fc_layers: Number of fully connected layers (1-8)
        base_kernel_size: Base kernel size (decreases by 2 per block, min 3).
            Even sizes are allowed. With ``padding = kernel_size // 2`` they
            grow the feature map by one pixel per convolution, and the
            classifier input size accounts for this.
        conv_dropout_rate: Dropout probability after convolutional blocks
        fc_dropout_rate: Dropout probability in fully connected layers
        pooling_strategy: Pooling type ('max' or 'avg')
        use_batch_norm: Whether to use batch normalization
        num_classes: Number of output classes (default: 10 for CIFAR-10)
        in_channels: Number of input channels (default: 3 for RGB)
        input_size: Input image size (default: 32 for CIFAR-10)

    Returns:
        nn.Sequential model

    Raises:
        ValueError: If ``pooling_strategy`` is not 'max' or 'avg', if
            ``n_conv_blocks`` is negative, or if ``input_size`` is too small
            to be pooled ``n_conv_blocks`` times.
    '''
    if pooling_strategy not in ('max', 'avg'):
        raise ValueError(f"pooling_strategy must be 'max' or 'avg', got {pooling_strategy!r}")

    if n_conv_blocks < 0:
        raise ValueError(f'n_conv_blocks must be non-negative, got {n_conv_blocks}')

    layers = []
    current_channels = in_channels
    current_size = input_size

    # Convolutional blocks
    for block_idx in range(n_conv_blocks):
        out_channels = initial_filters * (2 ** block_idx)
        kernel_size = max(3, base_kernel_size - 2 * block_idx)
        padding = kernel_size // 2

        # Two convolutions per block: in -> out, then out -> out
        for conv_in_channels in (current_channels, out_channels):
            layers.append(nn.Conv2d(conv_in_channels, out_channels, kernel_size=kernel_size, padding=padding))

            if use_batch_norm:
                layers.append(nn.BatchNorm2d(out_channels))

            layers.append(nn.ReLU())

            # Stride-1 conv output size: unchanged for odd kernels, +1 for even
            current_size = current_size + 2 * padding - kernel_size + 1

        if current_size < 2:
            raise ValueError(
                f'input_size={input_size} is too small for {n_conv_blocks} conv blocks'
            )

        # Pooling
        if pooling_strategy == 'max':
            layers.append(nn.MaxPool2d(2, 2))
        else:
            layers.append(nn.AvgPool2d(2, 2))

        layers.append(nn.Dropout(conv_dropout_rate))

        current_channels = out_channels
        current_size //= 2  # 2x2 / stride-2 pooling floors odd sizes

    # Calculate flattened size from the tracked channel count and spatial size
    flattened_size = current_channels * current_size * current_size

    # Classifier - dynamic FC layers with halving pattern
    layers.append(nn.Flatten())

    in_features = flattened_size
    current_fc_size = flattened_size // 2

    for _ in range(n_fc_layers):
        fc_size = max(10, current_fc_size)  # Minimum 10 units
        layers.append(nn.Linear(in_features, fc_size))
        layers.append(nn.ReLU())
        layers.append(nn.Dropout(fc_dropout_rate))
        in_features = fc_size
        current_fc_size //= 2

    # Output layer
    layers.append(nn.Linear(in_features, num_classes))

    return nn.Sequential(*layers)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def train_trial(
    model: nn.Module,
    optimizer: optim.Optimizer,
    criterion: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int,
    trial: optuna.Trial,
    device: torch.device | str | None = None
) -> float:
    '''Train a model for a single Optuna trial with pruning support.

    After every epoch the model is evaluated on ``val_loader``. That epoch's
    validation accuracy (in percent) is reported to ``trial``, and the trial
    is pruned if ``trial.should_prune()`` says so.

    Args:
        model: PyTorch model to train
        optimizer: Optimizer for training
        criterion: Loss function
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        n_epochs: Number of epochs to train
        trial: Optuna trial object for reporting and pruning
        device: Device to move each batch to (same convention as train_model).
            If None, batches are used as-is, i.e. the loaders are assumed to
            already yield data on the model's device (GPU pre-loading).

    Returns:
        Best validation accuracy (percent) achieved during training

    Raises:
        optuna.TrialPruned: If the trial is pruned after an epoch.
        ValueError: If ``val_loader`` yields no samples.
    '''
    best_val_accuracy = 0.0

    for epoch in range(n_epochs):

        # Training phase
        model.train()

        for images, labels in train_loader:
            if device is not None:
                images = images.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()

        # Validation phase
        model.eval()
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for images, labels in val_loader:
                if device is not None:
                    images = images.to(device, non_blocking=True)
                    labels = labels.to(device, non_blocking=True)

                predicted = model(images).argmax(dim=1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        if val_total == 0:
            raise ValueError('val_loader yielded no samples; cannot compute validation accuracy')

        val_accuracy = 100 * val_correct / val_total
        best_val_accuracy = max(best_val_accuracy, val_accuracy)

        # Report intermediate value for pruning
        trial.report(val_accuracy, epoch)

        # Prune unpromising trials
        if trial.should_prune():
            raise optuna.TrialPruned()

    return best_val_accuracy
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def create_objective(
    data_dir,
    train_transform,
    eval_transform,
    n_epochs: int,
    device: torch.device,
    num_classes: int = 10,
    in_channels: int = 3
) -> Callable[[optuna.Trial], float]:
    '''Create an Optuna objective function for CNN hyperparameter optimization.

    This factory function creates a closure that captures the data loading
    parameters and training configuration, returning an objective function
    suitable for Optuna.

    The train/validation datasets are built with make_data_loaders (preloaded
    onto ``device``) on the first trial only and reused by every later trial.
    Each trial wraps them in fresh DataLoaders using its own suggested batch
    size. As a result, all trials are scored on the same validation split, and
    the dataset is not re-read and re-transformed for every trial.

    Args:
        data_dir: Directory containing CIFAR-10 data
        train_transform: Transform to apply to training data
        eval_transform: Transform to apply to validation data
        n_epochs: Number of epochs per trial
        device: Device to train on (cuda or cpu)
        num_classes: Number of output classes (default: 10)
        in_channels: Number of input channels (default: 3 for RGB)

    Returns:
        Objective function for optuna.Study.optimize(). A trial that raises
        torch.cuda.OutOfMemoryError while building or training its model is
        reported as pruned instead of failing the study.

    Example:
        >>> objective = create_objective(data_dir, transform, transform, n_epochs=50, device=device)
        >>> study = optuna.create_study(direction='maximize')
        >>> study.optimize(objective, n_trials=100)
    '''

    # Train/val datasets shared across trials; filled lazily by the first trial
    cached_datasets = {}

    def get_loaders(batch_size: int) -> tuple[DataLoader, DataLoader]:
        '''Return train/val DataLoaders over the cached datasets.'''

        if not cached_datasets:
            train_loader, val_loader, _ = make_data_loaders(
                data_dir=data_dir,
                batch_size=batch_size,
                train_transform=train_transform,
                eval_transform=eval_transform,
                device=device,
                download=False
            )
            cached_datasets['train'] = train_loader.dataset
            cached_datasets['val'] = val_loader.dataset

        train_loader = DataLoader(cached_datasets['train'], batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(cached_datasets['val'], batch_size=batch_size, shuffle=False)

        return train_loader, val_loader

    def objective(trial: optuna.Trial) -> float:
        '''Optuna objective function for CNN hyperparameter optimization.'''

        # Suggest hyperparameters
        batch_size = trial.suggest_categorical('batch_size', [64, 128, 256, 512, 1024])
        n_conv_blocks = trial.suggest_int('n_conv_blocks', 1, 5)
        initial_filters = trial.suggest_categorical('initial_filters', [8, 16, 32, 64, 128])
        n_fc_layers = trial.suggest_int('n_fc_layers', 1, 8)
        base_kernel_size = trial.suggest_int('base_kernel_size', 3, 7)
        conv_dropout_rate = trial.suggest_float('conv_dropout_rate', 0.0, 0.5)
        fc_dropout_rate = trial.suggest_float('fc_dropout_rate', 0.2, 0.75)
        pooling_strategy = trial.suggest_categorical('pooling_strategy', ['max', 'avg'])
        use_batch_norm = trial.suggest_categorical('use_batch_norm', [True, False])
        learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-1, log=True)
        optimizer_name = trial.suggest_categorical('optimizer', ['Adam', 'SGD', 'RMSprop'])

        # Data loaders with suggested batch size (datasets shared across trials)
        train_loader, val_loader = get_loaders(batch_size)

        criterion = nn.CrossEntropyLoss()
        model = optimizer = None

        try:
            # Model construction and .to(device) can run out of memory as well
            model = create_cnn(
                n_conv_blocks=n_conv_blocks,
                initial_filters=initial_filters,
                n_fc_layers=n_fc_layers,
                base_kernel_size=base_kernel_size,
                conv_dropout_rate=conv_dropout_rate,
                fc_dropout_rate=fc_dropout_rate,
                pooling_strategy=pooling_strategy,
                use_batch_norm=use_batch_norm,
                num_classes=num_classes,
                in_channels=in_channels
            ).to(device)

            # Define optimizer
            if optimizer_name == 'Adam':
                optimizer = optim.Adam(model.parameters(), lr=learning_rate)

            elif optimizer_name == 'SGD':
                momentum = trial.suggest_float('sgd_momentum', 0.8, 0.99)
                optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)

            else:  # RMSprop
                optimizer = optim.RMSprop(model.parameters(), lr=learning_rate)

            # Train model and return best validation accuracy
            return train_trial(
                model=model,
                optimizer=optimizer,
                criterion=criterion,
                train_loader=train_loader,
                val_loader=val_loader,
                n_epochs=n_epochs,
                trial=trial
            )

        except torch.cuda.OutOfMemoryError:
            # Handled below: while the exception is being handled, its
            # traceback keeps the failed trial's tensors alive, so memory is
            # freed only after leaving this handler.
            pass

        # Only reached after a CUDA OOM: drop references, clear cache, skip trial
        model = optimizer = None
        torch.cuda.empty_cache()
        raise optuna.TrialPruned(f'CUDA OOM with params: {trial.params}')

    return objective
|
|
@@ -10,18 +10,18 @@ def plot_sample_images(
|
|
|
10
10
|
class_names: list[str],
|
|
11
11
|
nrows: int = 2,
|
|
12
12
|
ncols: int = 5,
|
|
13
|
-
figsize: tuple[float, float] | None = None
|
|
14
|
-
cmap: str = 'gray'
|
|
13
|
+
figsize: tuple[float, float] | None = None
|
|
15
14
|
) -> tuple[plt.Figure, np.ndarray]:
|
|
16
15
|
'''Plot sample images from a dataset.
|
|
17
16
|
|
|
17
|
+
Automatically handles both grayscale (1 channel) and RGB (3 channel) images.
|
|
18
|
+
|
|
18
19
|
Args:
|
|
19
20
|
dataset: PyTorch dataset containing (image, label) tuples.
|
|
20
21
|
class_names: List of class names for labeling.
|
|
21
22
|
nrows: Number of rows in the grid.
|
|
22
23
|
ncols: Number of columns in the grid.
|
|
23
24
|
figsize: Figure size (width, height). Defaults to (ncols*1.5, nrows*1.5).
|
|
24
|
-
cmap: Colormap for displaying images.
|
|
25
25
|
|
|
26
26
|
Returns:
|
|
27
27
|
Tuple of (figure, axes array).
|
|
@@ -36,11 +36,24 @@ def plot_sample_images(
|
|
|
36
36
|
# Get image and label from dataset
|
|
37
37
|
img, label = dataset[i]
|
|
38
38
|
|
|
39
|
-
# Unnormalize
|
|
39
|
+
# Unnormalize for plotting
|
|
40
40
|
img = img * 0.5 + 0.5
|
|
41
|
-
img = img.numpy()
|
|
41
|
+
img = img.numpy()
|
|
42
|
+
|
|
43
|
+
# Handle grayscale vs RGB images
|
|
44
|
+
if img.shape[0] == 1:
|
|
45
|
+
|
|
46
|
+
# Grayscale: squeeze channel dimension
|
|
47
|
+
img = img.squeeze()
|
|
48
|
+
ax.imshow(img, cmap='gray')
|
|
49
|
+
|
|
50
|
+
else:
|
|
51
|
+
|
|
52
|
+
# RGB: transpose from (C, H, W) to (H, W, C)
|
|
53
|
+
img = np.transpose(img, (1, 2, 0))
|
|
54
|
+
ax.imshow(img)
|
|
55
|
+
|
|
42
56
|
ax.set_title(class_names[label])
|
|
43
|
-
ax.imshow(img, cmap=cmap)
|
|
44
57
|
ax.axis('off')
|
|
45
58
|
|
|
46
59
|
plt.tight_layout()
|
|
@@ -209,7 +222,7 @@ def plot_evaluation_curves(
|
|
|
209
222
|
roc_auc = auc(fpr, tpr)
|
|
210
223
|
ax1.plot(fpr, tpr, label=class_name)
|
|
211
224
|
|
|
212
|
-
ax1.plot([0, 1], [0, 1], 'k--', label='
|
|
225
|
+
ax1.plot([0, 1], [0, 1], 'k--', label='random classifier')
|
|
213
226
|
ax1.set_xlabel('False positive rate')
|
|
214
227
|
ax1.set_ylabel('True positive rate')
|
|
215
228
|
ax1.legend(loc='lower right', fontsize=12)
|
|
@@ -235,4 +248,60 @@ def plot_evaluation_curves(
|
|
|
235
248
|
|
|
236
249
|
plt.tight_layout()
|
|
237
250
|
|
|
238
|
-
return fig, (ax1, ax2)
|
|
251
|
+
return fig, (ax1, ax2)
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def plot_optimization_results(
    study,
    figsize: tuple[float, float] = (12, 4)
) -> tuple[plt.Figure, np.ndarray]:
    '''Plot Optuna optimization history and hyperparameter importance.

    The left panel plots the value of every trial whose ``value`` is not None,
    plus a horizontal line at ``study.best_value`` when at least one trial has
    completed. The right panel shows parameter importances once five or more
    trials have completed, sorted so the most important parameter is on top.

    Args:
        study: Optuna study object.
        figsize: Figure size (width, height).

    Returns:
        Tuple of (figure, axes array).
    '''
    import optuna

    fig, axes = plt.subplots(1, 2, figsize=figsize)

    # study.trials returns copies on each access, so read it once
    trials = study.trials
    completed_trials = [t for t in trials if t.state == optuna.trial.TrialState.COMPLETE]

    # Optimization history
    axes[0].set_title('Optimization History')

    scored_trials = [t for t in trials if t.value is not None]
    axes[0].plot(
        [t.number for t in scored_trials],
        [t.value for t in scored_trials],
        'ko-', alpha=0.6
    )

    # study.best_value raises when no trial has completed (e.g. all were pruned)
    if completed_trials:
        axes[0].axhline(
            y=study.best_value,
            color='r', linestyle='--', label=f'Best: {study.best_value:.2f}%'
        )
        axes[0].legend()

    axes[0].set_xlabel('Trial')
    axes[0].set_ylabel('Validation Accuracy (%)')

    # Hyperparameter importance (if enough trials completed)
    axes[1].set_title('Hyperparameter Importance')

    if len(completed_trials) >= 5:
        importance = optuna.importance.get_param_importances(study)

        # Ascending order: barh draws the last entry at the top
        ranked = sorted(importance.items(), key=lambda item: item[1])

        axes[1].set_xlabel('Importance')
        axes[1].barh(
            [name for name, _ in ranked],
            [value for _, value in ranked],
            color='black'
        )

    else:
        axes[1].text(
            0.5, 0.5,
            'Not enough completed trials\nfor importance analysis',
            ha='center', va='center', transform=axes[1].transAxes
        )

    plt.tight_layout()

    return fig, axes
|
|
@@ -12,11 +12,28 @@ def train_model(
|
|
|
12
12
|
criterion: nn.Module,
|
|
13
13
|
optimizer: optim.Optimizer,
|
|
14
14
|
epochs: int = 10,
|
|
15
|
-
print_every: int = 1
|
|
15
|
+
print_every: int = 1,
|
|
16
|
+
device: torch.device | str | None = None
|
|
16
17
|
) -> dict[str, list[float]]:
|
|
17
18
|
'''Training loop for PyTorch classification model.
|
|
18
19
|
|
|
19
|
-
|
|
20
|
+
Handles both pre-loaded GPU data and lazy-loading (CPU data moved per-batch).
|
|
21
|
+
|
|
22
|
+
Args:
|
|
23
|
+
model: PyTorch model to train.
|
|
24
|
+
train_loader: DataLoader for training data.
|
|
25
|
+
val_loader: DataLoader for validation data.
|
|
26
|
+
criterion: Loss function.
|
|
27
|
+
optimizer: Optimizer.
|
|
28
|
+
epochs: Number of training epochs.
|
|
29
|
+
print_every: Print progress every n epochs.
|
|
30
|
+
device: Device to move data to per-batch. If None, assumes data is
|
|
31
|
+
already on the correct device (GPU pre-loading). If specified,
|
|
32
|
+
data will be moved to this device per-batch (lazy loading).
|
|
33
|
+
|
|
34
|
+
Returns:
|
|
35
|
+
Dictionary containing training history with keys:
|
|
36
|
+
'train_loss', 'val_loss', 'train_accuracy', 'val_accuracy'.
|
|
20
37
|
'''
|
|
21
38
|
|
|
22
39
|
history = {'train_loss': [], 'val_loss': [], 'train_accuracy': [], 'val_accuracy': []}
|
|
@@ -30,6 +47,11 @@ def train_model(
|
|
|
30
47
|
total = 0
|
|
31
48
|
|
|
32
49
|
for images, labels in train_loader:
|
|
50
|
+
|
|
51
|
+
# Move to device if lazy loading
|
|
52
|
+
if device is not None:
|
|
53
|
+
images = images.to(device, non_blocking=True)
|
|
54
|
+
labels = labels.to(device, non_blocking=True)
|
|
33
55
|
|
|
34
56
|
# Forward pass
|
|
35
57
|
optimizer.zero_grad()
|
|
@@ -59,6 +81,11 @@ def train_model(
|
|
|
59
81
|
with torch.no_grad():
|
|
60
82
|
|
|
61
83
|
for images, labels in val_loader:
|
|
84
|
+
|
|
85
|
+
# Move to device if lazy loading
|
|
86
|
+
if device is not None:
|
|
87
|
+
images = images.to(device, non_blocking=True)
|
|
88
|
+
labels = labels.to(device, non_blocking=True)
|
|
62
89
|
|
|
63
90
|
outputs = model(images)
|
|
64
91
|
loss = criterion(outputs, labels)
|
|
@@ -88,6 +115,4 @@ def train_model(
|
|
|
88
115
|
f'val_accuracy: {val_accuracy:.2f}%'
|
|
89
116
|
)
|
|
90
117
|
|
|
91
|
-
print('\nTraining complete.')
|
|
92
|
-
|
|
93
118
|
return history
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: cifar10_tools
|
|
3
|
+
Version: 0.4.0
|
|
4
|
+
Summary: Tools for training neural networks on the CIFAR-10 task with PyTorch and TensorFlow
|
|
5
|
+
License: GPLv3
|
|
6
|
+
License-File: LICENSE
|
|
7
|
+
Keywords: Python,Machine learning,Deep learning,CNNs,Computer vision,Image classification,CIFAR-10
|
|
8
|
+
Author: gperdrizet
|
|
9
|
+
Author-email: george@perdrizet.org
|
|
10
|
+
Requires-Python: >=3.10,<3.13
|
|
11
|
+
Classifier: Development Status :: 3 - Alpha
|
|
12
|
+
Classifier: Intended Audience :: Developers
|
|
13
|
+
Classifier: Intended Audience :: Education
|
|
14
|
+
Classifier: Intended Audience :: Science/Research
|
|
15
|
+
Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
|
|
16
|
+
Classifier: License :: Other/Proprietary License
|
|
17
|
+
Classifier: Operating System :: OS Independent
|
|
18
|
+
Classifier: Programming Language :: Python :: 3
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
20
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
21
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
22
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
23
|
+
Classifier: Topic :: Scientific/Engineering :: Image Recognition
|
|
24
|
+
Provides-Extra: tensorflow
|
|
25
|
+
Requires-Dist: numpy (>=1.24)
|
|
26
|
+
Requires-Dist: torch (>=2.0)
|
|
27
|
+
Requires-Dist: torchvision (>=0.15)
|
|
28
|
+
Project-URL: Documentation, https://gperdrizet.github.io/CIFAR10/README.md
|
|
29
|
+
Project-URL: Homepage, https://github.com/gperdrizet/CIFAR10
|
|
30
|
+
Project-URL: Issues, https://github.com/gperdrizet/CIFAR10/issues
|
|
31
|
+
Project-URL: PyPI, https://pypi.org/project/cifar10_tools
|
|
32
|
+
Project-URL: Repository, https://github.com/gperdrizet/CIFAR10
|
|
33
|
+
Description-Content-Type: text/markdown
|
|
34
|
+
|
|
35
|
+
# PyTorch: CIFAR-10 Demonstration
|
|
36
|
+
|
|
37
|
+
A progressive deep learning tutorial for image classification on the CIFAR-10 dataset using PyTorch. This project demonstrates the evolution from basic deep neural networks to optimized convolutional neural networks with data augmentation. It also provides a set of utility functions as a PyPI package for use in other projects.
|
|
38
|
+
|
|
39
|
+
[View on PyPI](https://pypi.org/project/cifar10_tools)
|
|
40
|
+
|
|
41
|
+
## Installation
|
|
42
|
+
|
|
43
|
+
Install the helper tools package locally in editable mode:
|
|
44
|
+
|
|
45
|
+
```bash
|
|
46
|
+
pip install -e .
|
|
47
|
+
```
|
|
48
|
+
|
|
49
|
+
## Project Overview
|
|
50
|
+
|
|
51
|
+
This repository contains a series of Jupyter notebooks that progressively build more sophisticated neural network architectures for the CIFAR-10 image classification task. Each notebook builds upon concepts from the previous one, demonstrating key deep learning techniques.
|
|
52
|
+
|
|
53
|
+
## Notebooks
|
|
54
|
+
|
|
55
|
+
| Notebook | Description |
|
|
56
|
+
|----------|-------------|
|
|
57
|
+
| [01-DNN.ipynb](notebooks/01-DNN.ipynb) | **Deep Neural Network** - Baseline fully-connected DNN classifier using `nn.Sequential`. Establishes a performance baseline with a simple architecture. |
|
|
58
|
+
| [02-CNN.ipynb](notebooks/02-CNN.ipynb) | **Convolutional Neural Network** - Introduction to CNNs with convolutional and pooling layers using `nn.Sequential`. Demonstrates the advantage of CNNs over DNNs for image tasks. |
|
|
59
|
+
| [03-RGB-CNN.ipynb](notebooks/03-RGB-CNN.ipynb) | **RGB CNN** - CNN classifier that utilizes full RGB color information instead of grayscale, improving feature extraction from color images. |
|
|
60
|
+
| [04-optimized-CNN.ipynb](notebooks/04-optimized-CNN.ipynb) | **Hyperparameter Optimization** - Uses Optuna for automated hyperparameter tuning to find optimal network architecture and training parameters. |
|
|
61
|
+
| [05-augmented-CNN.ipynb](notebooks/05-augmented-CNN.ipynb) | **Data Augmentation** - Trains the optimized CNN architecture with image augmentation techniques for improved generalization and robustness. |
|
|
62
|
+
|
|
63
|
+
## Requirements
|
|
64
|
+
|
|
65
|
+
- Python >=3.10, <3.13
|
|
66
|
+
- PyTorch >=2.0
|
|
67
|
+
- torchvision >=0.15
|
|
68
|
+
- numpy >=1.24
- matplotlib (plotting utilities)
- optuna (hyperparameter optimization utilities)
|
|
69
|
+
|
|
70
|
+
## License
|
|
71
|
+
|
|
72
|
+
This project is licensed under the GPLv3 License - see the [LICENSE](LICENSE) file for details.
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
cifar10_tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
+
cifar10_tools/pytorch/__init__.py,sha256=4er-aMGK-MZTlkH3Owz3x-Pz_Gl_NjplKwOBYdBA1p0,909
|
|
3
|
+
cifar10_tools/pytorch/data.py,sha256=ZJb_EYxHPh6wsnAtzRcDFVVZaa3ChAbnC5IHaWaf0Ls,4272
|
|
4
|
+
cifar10_tools/pytorch/evaluation.py,sha256=i4tRYOqWATVqQVkWT_fATWRbzo9ziX2DDkXKPaiQlFE,923
|
|
5
|
+
cifar10_tools/pytorch/hyperparameter_optimization.py,sha256=kosd937gLC_QfamC1dVm9DQ2P6VYVqETMlF6t3de23c,9671
|
|
6
|
+
cifar10_tools/pytorch/plotting.py,sha256=SB50bwY4qhvYu_cVNT7EAE2vwOI8-0pxwu7jwGTJRas,9550
|
|
7
|
+
cifar10_tools/pytorch/training.py,sha256=spam_Q1G1ZAoheMMKY26RHl6YhIam8pW6A7Df7oS1to,3824
|
|
8
|
+
cifar10_tools/tensorflow/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
9
|
+
cifar10_tools-0.4.0.dist-info/METADATA,sha256=X8Ktr3qlNTWuJzrOUl_EktGs5CAJDY_LIPVLNf3d-Vw,3670
|
|
10
|
+
cifar10_tools-0.4.0.dist-info/WHEEL,sha256=kJCRJT_g0adfAJzTx2GUMmS80rTJIVHRCfG0DQgLq3o,88
|
|
11
|
+
cifar10_tools-0.4.0.dist-info/licenses/LICENSE,sha256=wtHfRwmCF5-_XUmYwrBKwJkGipvHVmh7GXJOKKeOe2U,1073
|
|
12
|
+
cifar10_tools-0.4.0.dist-info/RECORD,,
|
|
@@ -1,35 +0,0 @@
|
|
|
1
|
-
Metadata-Version: 2.4
|
|
2
|
-
Name: cifar10_tools
|
|
3
|
-
Version: 0.2.0
|
|
4
|
-
Summary: Tools for training neural networks on the CIFAR-10 task with PyTorch and TensorFlow
|
|
5
|
-
License: GPLv3
|
|
6
|
-
License-File: LICENSE
|
|
7
|
-
Keywords: Python,Machine learning,Deep learning,CNNs,Computer vision,Image classification,CIFAR-10
|
|
8
|
-
Author: gperdrizet
|
|
9
|
-
Author-email: george@perdrizet.org
|
|
10
|
-
Requires-Python: >=3.10,<3.13
|
|
11
|
-
Classifier: Development Status :: 3 - Alpha
|
|
12
|
-
Classifier: Intended Audience :: Developers
|
|
13
|
-
Classifier: Intended Audience :: Education
|
|
14
|
-
Classifier: Intended Audience :: Science/Research
|
|
15
|
-
Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
|
|
16
|
-
Classifier: License :: Other/Proprietary License
|
|
17
|
-
Classifier: Operating System :: OS Independent
|
|
18
|
-
Classifier: Programming Language :: Python :: 3
|
|
19
|
-
Classifier: Programming Language :: Python :: 3.10
|
|
20
|
-
Classifier: Programming Language :: Python :: 3.11
|
|
21
|
-
Classifier: Programming Language :: Python :: 3.12
|
|
22
|
-
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
23
|
-
Classifier: Topic :: Scientific/Engineering :: Image Recognition
|
|
24
|
-
Provides-Extra: tensorflow
|
|
25
|
-
Requires-Dist: numpy (>=1.24)
|
|
26
|
-
Requires-Dist: torch (>=2.0)
|
|
27
|
-
Requires-Dist: torchvision (>=0.15)
|
|
28
|
-
Project-URL: Documentation, https://gperdrizet.github.io/CIFAR10/README.md
|
|
29
|
-
Project-URL: Homepage, https://github.com/gperdrizet/CIFAR10
|
|
30
|
-
Project-URL: Issues, https://github.com/gperdrizet/CIFAR10/issues
|
|
31
|
-
Project-URL: PyPI, https://pypi.org/project/cifar10_tools
|
|
32
|
-
Project-URL: Repository, https://github.com/gperdrizet/CIFAR10
|
|
33
|
-
Description-Content-Type: text/markdown
|
|
34
|
-
|
|
35
|
-
# PyTorch: CIFAR10 demonstration
|
|
@@ -1,11 +0,0 @@
|
|
|
1
|
-
cifar10_tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
-
cifar10_tools/pytorch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
|
-
cifar10_tools/pytorch/data.py,sha256=09zodpjto0xLq95tDAyq57CFh6MSYRuUBPcMmQcyKZM,626
|
|
4
|
-
cifar10_tools/pytorch/evaluation.py,sha256=i4tRYOqWATVqQVkWT_fATWRbzo9ziX2DDkXKPaiQlFE,923
|
|
5
|
-
cifar10_tools/pytorch/plotting.py,sha256=B1ifJxbSEDpInnVk9c3o1fjVx534TPPKTWM5iusyzrE,7494
|
|
6
|
-
cifar10_tools/pytorch/training.py,sha256=Sg6NlBT_DTyLzf-Ls3bYI8-8AwGFJblRj0MDnUmGP3Q,2642
|
|
7
|
-
cifar10_tools/tensorflow/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
8
|
-
cifar10_tools-0.2.0.dist-info/METADATA,sha256=3s6_5lP8rAnEu5F9r5YKU-EqUi9UO3mNUFK1ikVgUfc,1580
|
|
9
|
-
cifar10_tools-0.2.0.dist-info/WHEEL,sha256=kJCRJT_g0adfAJzTx2GUMmS80rTJIVHRCfG0DQgLq3o,88
|
|
10
|
-
cifar10_tools-0.2.0.dist-info/licenses/LICENSE,sha256=wtHfRwmCF5-_XUmYwrBKwJkGipvHVmh7GXJOKKeOe2U,1073
|
|
11
|
-
cifar10_tools-0.2.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|