cifar10-tools 0.2.0__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,72 @@
+ Metadata-Version: 2.4
+ Name: cifar10_tools
+ Version: 0.4.0
+ Summary: Tools for training neural networks on the CIFAR-10 task with PyTorch and TensorFlow
+ License: GPLv3
+ License-File: LICENSE
+ Keywords: Python,Machine learning,Deep learning,CNNs,Computer vision,Image classification,CIFAR-10
+ Author: gperdrizet
+ Author-email: george@perdrizet.org
+ Requires-Python: >=3.10,<3.13
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: Intended Audience :: Developers
+ Classifier: Intended Audience :: Education
+ Classifier: Intended Audience :: Science/Research
+ Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
+ Classifier: License :: Other/Proprietary License
+ Classifier: Operating System :: OS Independent
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Classifier: Topic :: Scientific/Engineering :: Image Recognition
+ Provides-Extra: tensorflow
+ Requires-Dist: numpy (>=1.24)
+ Requires-Dist: torch (>=2.0)
+ Requires-Dist: torchvision (>=0.15)
+ Project-URL: Documentation, https://gperdrizet.github.io/CIFAR10/README.md
+ Project-URL: Homepage, https://github.com/gperdrizet/CIFAR10
+ Project-URL: Issues, https://github.com/gperdrizet/CIFAR10/issues
+ Project-URL: PyPI, https://pypi.org/project/cifar10_tools
+ Project-URL: Repository, https://github.com/gperdrizet/CIFAR10
+ Description-Content-Type: text/markdown
+
+ # PyTorch: CIFAR-10 Demonstration
+
+ A progressive deep learning tutorial for image classification on the CIFAR-10 dataset using PyTorch. This project demonstrates the evolution from basic deep neural networks to optimized convolutional neural networks with data augmentation. It also provides a set of utility functions as a PyPI package for use in other projects.
+
+ [View on PyPI](https://pypi.org/project/cifar10_tools)
+
+ ## Installation
+
+ Install the helper tools package locally in editable mode:
+
+ ```bash
+ pip install -e .
+ ```
+
+ ## Project Overview
+
+ This repository contains a series of Jupyter notebooks that progressively build more sophisticated neural network architectures for the CIFAR-10 image classification task. Each notebook builds upon concepts from the previous one, demonstrating key deep learning techniques.
+
+ ## Notebooks
+
+ | Notebook | Description |
+ |----------|-------------|
+ | [01-DNN.ipynb](notebooks/01-DNN.ipynb) | **Deep Neural Network** - Baseline fully-connected DNN classifier using `nn.Sequential`. Establishes a performance baseline with a simple architecture. |
+ | [02-CNN.ipynb](notebooks/02-CNN.ipynb) | **Convolutional Neural Network** - Introduction to CNNs with convolutional and pooling layers using `nn.Sequential`. Demonstrates the advantage of CNNs over DNNs for image tasks. |
+ | [03-RGB-CNN.ipynb](notebooks/03-RGB-CNN.ipynb) | **RGB CNN** - CNN classifier that utilizes full RGB color information instead of grayscale, improving feature extraction from color images. |
+ | [04-optimized-CNN.ipynb](notebooks/04-optimized-CNN.ipynb) | **Hyperparameter Optimization** - Uses Optuna for automated hyperparameter tuning to find optimal network architecture and training parameters. |
+ | [05-augmented-CNN.ipynb](notebooks/05-augmented-CNN.ipynb) | **Data Augmentation** - Trains the optimized CNN architecture with image augmentation techniques for improved generalization and robustness. |
+
+ ## Requirements
+
+ - Python >=3.10, <3.13
+ - PyTorch >=2.0
+ - torchvision >=0.15
+ - numpy >=1.24
+
+ ## License
+
+ This project is licensed under the GPLv3 License - see the [LICENSE](LICENSE) file for details.
@@ -0,0 +1,38 @@
+ # PyTorch: CIFAR-10 Demonstration
+
+ A progressive deep learning tutorial for image classification on the CIFAR-10 dataset using PyTorch. This project demonstrates the evolution from basic deep neural networks to optimized convolutional neural networks with data augmentation. It also provides a set of utility functions as a PyPI package for use in other projects.
+
+ [View on PyPI](https://pypi.org/project/cifar10_tools)
+
+ ## Installation
+
+ Install the helper tools package locally in editable mode:
+
+ ```bash
+ pip install -e .
+ ```
+
+ ## Project Overview
+
+ This repository contains a series of Jupyter notebooks that progressively build more sophisticated neural network architectures for the CIFAR-10 image classification task. Each notebook builds upon concepts from the previous one, demonstrating key deep learning techniques.
+
+ ## Notebooks
+
+ | Notebook | Description |
+ |----------|-------------|
+ | [01-DNN.ipynb](notebooks/01-DNN.ipynb) | **Deep Neural Network** - Baseline fully-connected DNN classifier using `nn.Sequential`. Establishes a performance baseline with a simple architecture. |
+ | [02-CNN.ipynb](notebooks/02-CNN.ipynb) | **Convolutional Neural Network** - Introduction to CNNs with convolutional and pooling layers using `nn.Sequential`. Demonstrates the advantage of CNNs over DNNs for image tasks. |
+ | [03-RGB-CNN.ipynb](notebooks/03-RGB-CNN.ipynb) | **RGB CNN** - CNN classifier that utilizes full RGB color information instead of grayscale, improving feature extraction from color images. |
+ | [04-optimized-CNN.ipynb](notebooks/04-optimized-CNN.ipynb) | **Hyperparameter Optimization** - Uses Optuna for automated hyperparameter tuning to find optimal network architecture and training parameters. |
+ | [05-augmented-CNN.ipynb](notebooks/05-augmented-CNN.ipynb) | **Data Augmentation** - Trains the optimized CNN architecture with image augmentation techniques for improved generalization and robustness. |
+
+ ## Requirements
+
+ - Python >=3.10, <3.13
+ - PyTorch >=2.0
+ - torchvision >=0.15
+ - numpy >=1.24
+
+ ## License
+
+ This project is licensed under the GPLv3 License - see the [LICENSE](LICENSE) file for details.
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

  [tool.poetry]
  name = "cifar10_tools"
- version = "0.2.0"
+ version = "0.4.0"
  description = "Tools for training neural networks on the CIFAR-10 task with PyTorch and TensorFlow"
  authors = ["gperdrizet <george@perdrizet.org>"]
  readme = "README.md"
@@ -0,0 +1,33 @@
+ '''PyTorch utilities for CIFAR-10 classification.'''
+
+ from cifar10_tools.pytorch.data import download_cifar10_data
+ from cifar10_tools.pytorch.evaluation import evaluate_model
+ from cifar10_tools.pytorch.training import train_model
+ from cifar10_tools.pytorch.plotting import (
+     plot_sample_images,
+     plot_learning_curves,
+     plot_confusion_matrix,
+     plot_class_probability_distributions,
+     plot_evaluation_curves,
+     plot_optimization_results
+ )
+ from cifar10_tools.pytorch.hyperparameter_optimization import (
+     create_cnn,
+     train_trial,
+     create_objective
+ )
+
+ __all__ = [
+     'download_cifar10_data',
+     'evaluate_model',
+     'train_model',
+     'plot_sample_images',
+     'plot_learning_curves',
+     'plot_confusion_matrix',
+     'plot_class_probability_distributions',
+     'plot_evaluation_curves',
+     'plot_optimization_results',
+     'create_cnn',
+     'train_trial',
+     'create_objective'
+ ]
@@ -0,0 +1,135 @@
+ '''Data loading and preprocessing functions for CIFAR-10 dataset.'''
+
+ from pathlib import Path
+ import torch
+ from torchvision import datasets, transforms
+ from torch.utils.data import DataLoader
+
+
+ def download_cifar10_data(data_dir: str='data/pytorch/cifar10'):
+     '''Download CIFAR-10 dataset using torchvision.datasets.'''
+
+     data_dir = Path(data_dir)
+     data_dir.mkdir(parents=True, exist_ok=True)
+
+     _ = datasets.CIFAR10(
+         root=data_dir,
+         train=True,
+         download=True
+     )
+
+     _ = datasets.CIFAR10(
+         root=data_dir,
+         train=False,
+         download=True
+     )
+
+
+ def make_data_loaders(
+     data_dir: Path,
+     batch_size: int,
+     train_transform: transforms.Compose,
+     eval_transform: transforms.Compose,
+     device: torch.device | None = None,
+     download: bool = False,
+ ):
+     """
+     Loads CIFAR-10, applies preprocessing with separate train/eval transforms,
+     and returns DataLoaders.
+
+     Args:
+         data_dir: Path to CIFAR-10 data directory
+         batch_size: Batch size for DataLoaders
+         train_transform: Transform to apply to training data
+         eval_transform: Transform to apply to validation and test data
+         device: Device to preload tensors onto. If None, data stays on CPU
+             and transforms are applied on-the-fly during iteration.
+         download: Whether to download the dataset if not present
+
+     Returns:
+         Tuple of (train_loader, val_loader, test_loader)
+     """
+
+     # Load datasets with respective transforms
+     train_dataset_full = datasets.CIFAR10(
+         root=data_dir,
+         train=True,
+         download=download,
+         transform=train_transform,
+     )
+
+     val_test_dataset_full = datasets.CIFAR10(
+         root=data_dir,
+         train=True,
+         download=download,
+         transform=eval_transform,
+     )
+
+     test_dataset = datasets.CIFAR10(
+         root=data_dir,
+         train=False,
+         download=download,
+         transform=eval_transform,
+     )
+
+     if device is not None:
+         # Preload entire dataset to device for faster training
+         X_train_full = torch.stack([img for img, _ in train_dataset_full]).to(device)
+         y_train_full = torch.tensor([label for _, label in train_dataset_full]).to(device)
+
+         X_val_test_full = torch.stack([img for img, _ in val_test_dataset_full]).to(device)
+         y_val_test_full = torch.tensor([label for _, label in val_test_dataset_full]).to(device)
+
+         X_test = torch.stack([img for img, _ in test_dataset]).to(device)
+         y_test = torch.tensor([label for _, label in test_dataset]).to(device)
+
+         # Train/val split (80/20)
+         n_train = int(0.8 * len(X_train_full))
+         indices = torch.randperm(len(X_train_full))
+
+         X_train = X_train_full[indices[:n_train]]
+         y_train = y_train_full[indices[:n_train]]
+         X_val = X_val_test_full[indices[n_train:]]
+         y_val = y_val_test_full[indices[n_train:]]
+
+         # TensorDatasets
+         train_tensor_dataset = torch.utils.data.TensorDataset(X_train, y_train)
+         val_tensor_dataset = torch.utils.data.TensorDataset(X_val, y_val)
+         test_tensor_dataset = torch.utils.data.TensorDataset(X_test, y_test)
+
+     else:
+         # Don't preload - use datasets directly for on-the-fly transforms
+         # Train/val split (80/20) using Subset
+         n_train = int(0.8 * len(train_dataset_full))
+         indices = torch.randperm(len(train_dataset_full)).tolist()
+
+         train_indices = indices[:n_train]
+         val_indices = indices[n_train:]
+
+         train_tensor_dataset = torch.utils.data.Subset(train_dataset_full, train_indices)
+         val_tensor_dataset = torch.utils.data.Subset(val_test_dataset_full, val_indices)
+         test_tensor_dataset = test_dataset
+
+     # DataLoaders
+     train_loader = DataLoader(
+         train_tensor_dataset,
+         batch_size=batch_size,
+         shuffle=True,
+     )
+     val_loader = DataLoader(
+         val_tensor_dataset,
+         batch_size=batch_size,
+         shuffle=False,
+     )
+     test_loader = DataLoader(
+         test_tensor_dataset,
+         batch_size=batch_size,
+         shuffle=False,
+     )
+
+     return train_loader, val_loader, test_loader
+
+
+ if __name__ == '__main__':
+
+     download_cifar10_data()
@@ -0,0 +1,278 @@
+ '''Hyperparameter optimization utilities for CNN models using Optuna.
+
+ This module provides functions for building configurable CNN architectures
+ and running hyperparameter optimization with Optuna.
+ '''
+
+ from typing import Callable
+
+ import optuna
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import DataLoader
+
+ from cifar10_tools.pytorch.data import make_data_loaders
+
+
+ def create_cnn(
+     n_conv_blocks: int,
+     initial_filters: int,
+     n_fc_layers: int,
+     base_kernel_size: int,
+     conv_dropout_rate: float,
+     fc_dropout_rate: float,
+     pooling_strategy: str,
+     use_batch_norm: bool,
+     num_classes: int = 10,
+     in_channels: int = 3,
+     input_size: int = 32
+ ) -> nn.Sequential:
+     '''Create a CNN with configurable architecture.
+
+     Args:
+         n_conv_blocks: Number of convolutional blocks (1-5)
+         initial_filters: Number of filters in first conv layer (doubles each block)
+         n_fc_layers: Number of fully connected layers (1-8)
+         base_kernel_size: Base kernel size (decreases by 2 per block, min 3)
+         conv_dropout_rate: Dropout probability after convolutional blocks
+         fc_dropout_rate: Dropout probability in fully connected layers
+         pooling_strategy: Pooling type ('max' or 'avg')
+         use_batch_norm: Whether to use batch normalization
+         num_classes: Number of output classes (default: 10 for CIFAR-10)
+         in_channels: Number of input channels (default: 3 for RGB)
+         input_size: Input image size (default: 32 for CIFAR-10)
+
+     Returns:
+         nn.Sequential model
+     '''
+     layers = []
+     current_channels = in_channels
+     current_size = input_size
+
+     # Convolutional blocks
+     for block_idx in range(n_conv_blocks):
+         out_channels = initial_filters * (2 ** block_idx)
+         kernel_size = max(3, base_kernel_size - 2 * block_idx)
+         padding = kernel_size // 2
+
+         # First conv in block
+         layers.append(nn.Conv2d(current_channels, out_channels, kernel_size=kernel_size, padding=padding))
+
+         if use_batch_norm:
+             layers.append(nn.BatchNorm2d(out_channels))
+
+         layers.append(nn.ReLU())
+
+         # Second conv in block
+         layers.append(nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=padding))
+
+         if use_batch_norm:
+             layers.append(nn.BatchNorm2d(out_channels))
+
+         layers.append(nn.ReLU())
+
+         # Pooling
+         if pooling_strategy == 'max':
+             layers.append(nn.MaxPool2d(2, 2))
+         else:  # avg
+             layers.append(nn.AvgPool2d(2, 2))
+
+         layers.append(nn.Dropout(conv_dropout_rate))
+
+         current_channels = out_channels
+         current_size //= 2
+
+     # Calculate flattened size
+     final_channels = initial_filters * (2 ** (n_conv_blocks - 1))
+     flattened_size = final_channels * current_size * current_size
+
+     # Classifier - dynamic FC layers with halving pattern
+     layers.append(nn.Flatten())
+
+     # Generate FC layer sizes by halving from flattened_size
+     fc_sizes = []
+     current_fc_size = flattened_size // 2
+     for _ in range(n_fc_layers):
+         fc_sizes.append(max(10, current_fc_size))  # Minimum 10 units
+         current_fc_size //= 2
+
+     # Add FC layers
+     in_features = flattened_size
+     for fc_size in fc_sizes:
+         layers.append(nn.Linear(in_features, fc_size))
+         layers.append(nn.ReLU())
+         layers.append(nn.Dropout(fc_dropout_rate))
+         in_features = fc_size
+
+     # Output layer
+     layers.append(nn.Linear(in_features, num_classes))
+
+     return nn.Sequential(*layers)
+
+
+ def train_trial(
+     model: nn.Module,
+     optimizer: optim.Optimizer,
+     criterion: nn.Module,
+     train_loader: DataLoader,
+     val_loader: DataLoader,
+     n_epochs: int,
+     trial: optuna.Trial
+ ) -> float:
+     '''Train a model for a single Optuna trial with pruning support.
+
+     Args:
+         model: PyTorch model to train
+         optimizer: Optimizer for training
+         criterion: Loss function
+         train_loader: DataLoader for training data
+         val_loader: DataLoader for validation data
+         n_epochs: Number of epochs to train
+         trial: Optuna trial object for reporting and pruning
+
+     Returns:
+         Best validation accuracy achieved during training
+     '''
+     best_val_accuracy = 0.0
+
+     for epoch in range(n_epochs):
+
+         # Training phase
+         model.train()
+
+         for images, labels in train_loader:
+             optimizer.zero_grad()
+             outputs = model(images)
+             loss = criterion(outputs, labels)
+             loss.backward()
+             optimizer.step()
+
+         # Validation phase
+         model.eval()
+         val_correct = 0
+         val_total = 0
+
+         with torch.no_grad():
+             for images, labels in val_loader:
+                 outputs = model(images)
+                 _, predicted = torch.max(outputs.data, 1)
+                 val_total += labels.size(0)
+                 val_correct += (predicted == labels).sum().item()
+
+         val_accuracy = 100 * val_correct / val_total
+         best_val_accuracy = max(best_val_accuracy, val_accuracy)
+
+         # Report intermediate value for pruning
+         trial.report(val_accuracy, epoch)
+
+         # Prune unpromising trials
+         if trial.should_prune():
+             raise optuna.TrialPruned()
+
+     return best_val_accuracy
+
+
+ def create_objective(
+     data_dir,
+     train_transform,
+     eval_transform,
+     n_epochs: int,
+     device: torch.device,
+     num_classes: int = 10,
+     in_channels: int = 3
+ ) -> Callable[[optuna.Trial], float]:
+     '''Create an Optuna objective function for CNN hyperparameter optimization.
+
+     This factory function creates a closure that captures the data loading parameters
+     and training configuration, returning an objective function suitable for Optuna.
+
+     Args:
+         data_dir: Directory containing CIFAR-10 data
+         train_transform: Transform to apply to training data
+         eval_transform: Transform to apply to validation data
+         n_epochs: Number of epochs per trial
+         device: Device to train on (cuda or cpu)
+         num_classes: Number of output classes (default: 10)
+         in_channels: Number of input channels (default: 3 for RGB)
+
+     Returns:
+         Objective function for optuna.Study.optimize()
+
+     Example:
+         >>> objective = create_objective(data_dir, transform, transform, n_epochs=50, device=device)
+         >>> study = optuna.create_study(direction='maximize')
+         >>> study.optimize(objective, n_trials=100)
+     '''
+
+     def objective(trial: optuna.Trial) -> float:
+         '''Optuna objective function for CNN hyperparameter optimization.'''
+
+         # Suggest hyperparameters
+         batch_size = trial.suggest_categorical('batch_size', [64, 128, 256, 512, 1024])
+         n_conv_blocks = trial.suggest_int('n_conv_blocks', 1, 5)
+         initial_filters = trial.suggest_categorical('initial_filters', [8, 16, 32, 64, 128])
+         n_fc_layers = trial.suggest_int('n_fc_layers', 1, 8)
+         base_kernel_size = trial.suggest_int('base_kernel_size', 3, 7)
+         conv_dropout_rate = trial.suggest_float('conv_dropout_rate', 0.0, 0.5)
+         fc_dropout_rate = trial.suggest_float('fc_dropout_rate', 0.2, 0.75)
+         pooling_strategy = trial.suggest_categorical('pooling_strategy', ['max', 'avg'])
+         use_batch_norm = trial.suggest_categorical('use_batch_norm', [True, False])
+         learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-1, log=True)
+         optimizer_name = trial.suggest_categorical('optimizer', ['Adam', 'SGD', 'RMSprop'])
+
+         # Create data loaders with suggested batch size
+         train_loader, val_loader, _ = make_data_loaders(
+             data_dir=data_dir,
+             batch_size=batch_size,
+             train_transform=train_transform,
+             eval_transform=eval_transform,
+             device=device,
+             download=False
+         )
+
+         # Create model
+         model = create_cnn(
+             n_conv_blocks=n_conv_blocks,
+             initial_filters=initial_filters,
+             n_fc_layers=n_fc_layers,
+             base_kernel_size=base_kernel_size,
+             conv_dropout_rate=conv_dropout_rate,
+             fc_dropout_rate=fc_dropout_rate,
+             pooling_strategy=pooling_strategy,
+             use_batch_norm=use_batch_norm,
+             num_classes=num_classes,
+             in_channels=in_channels
+         ).to(device)
+
+         # Define optimizer
+         if optimizer_name == 'Adam':
+             optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+
+         elif optimizer_name == 'SGD':
+             momentum = trial.suggest_float('sgd_momentum', 0.8, 0.99)
+             optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
+
+         else:  # RMSprop
+             optimizer = optim.RMSprop(model.parameters(), lr=learning_rate)
+
+         criterion = nn.CrossEntropyLoss()
+
+         # Train model and return best validation accuracy
+         try:
+             return train_trial(
+                 model=model,
+                 optimizer=optimizer,
+                 criterion=criterion,
+                 train_loader=train_loader,
+                 val_loader=val_loader,
+                 n_epochs=n_epochs,
+                 trial=trial
+             )
+
+         except torch.cuda.OutOfMemoryError:
+             # Clear CUDA cache and skip this trial
+             torch.cuda.empty_cache()
+             raise optuna.TrialPruned(f'CUDA OOM with params: {trial.params}')
+
+     return objective
@@ -10,18 +10,18 @@ def plot_sample_images(
      class_names: list[str],
      nrows: int = 2,
      ncols: int = 5,
-     figsize: tuple[float, float] | None = None,
-     cmap: str = 'gray'
+     figsize: tuple[float, float] | None = None
  ) -> tuple[plt.Figure, np.ndarray]:
      '''Plot sample images from a dataset.

+     Automatically handles both grayscale (1 channel) and RGB (3 channel) images.
+
      Args:
          dataset: PyTorch dataset containing (image, label) tuples.
          class_names: List of class names for labeling.
          nrows: Number of rows in the grid.
          ncols: Number of columns in the grid.
          figsize: Figure size (width, height). Defaults to (ncols*1.5, nrows*1.5).
-         cmap: Colormap for displaying images.

      Returns:
          Tuple of (figure, axes array).
@@ -36,11 +36,24 @@ def plot_sample_images(
          # Get image and label from dataset
          img, label = dataset[i]

-         # Unnormalize and squeeze for plotting
+         # Unnormalize for plotting
          img = img * 0.5 + 0.5
-         img = img.numpy().squeeze()
+         img = img.numpy()
+
+         # Handle grayscale vs RGB images
+         if img.shape[0] == 1:
+
+             # Grayscale: squeeze channel dimension
+             img = img.squeeze()
+             ax.imshow(img, cmap='gray')
+
+         else:
+
+             # RGB: transpose from (C, H, W) to (H, W, C)
+             img = np.transpose(img, (1, 2, 0))
+             ax.imshow(img)
+
          ax.set_title(class_names[label])
-         ax.imshow(img, cmap=cmap)
          ax.axis('off')

      plt.tight_layout()
@@ -209,7 +222,7 @@ def plot_evaluation_curves(
          roc_auc = auc(fpr, tpr)
          ax1.plot(fpr, tpr, label=class_name)

-     ax1.plot([0, 1], [0, 1], 'k--', label='Random classifier')
+     ax1.plot([0, 1], [0, 1], 'k--', label='random classifier')
      ax1.set_xlabel('False positive rate')
      ax1.set_ylabel('True positive rate')
      ax1.legend(loc='lower right', fontsize=12)
@@ -235,4 +248,60 @@ def plot_evaluation_curves(

      plt.tight_layout()

-     return fig, (ax1, ax2)
+     return fig, (ax1, ax2)
+
+
+ def plot_optimization_results(
+     study,
+     figsize: tuple[float, float] = (12, 4)
+ ) -> tuple[plt.Figure, np.ndarray]:
+     '''Plot Optuna optimization history and hyperparameter importance.
+
+     Args:
+         study: Optuna study object with completed trials.
+         figsize: Figure size (width, height).
+
+     Returns:
+         Tuple of (figure, axes array).
+     '''
+     import optuna
+
+     fig, axes = plt.subplots(1, 2, figsize=figsize)
+
+     # Optimization history
+     axes[0].set_title('Optimization History')
+
+     trial_numbers = [t.number for t in study.trials if t.value is not None]
+     trial_values = [t.value for t in study.trials if t.value is not None]
+
+     axes[0].plot(trial_numbers, trial_values, 'ko-', alpha=0.6)
+     axes[0].axhline(
+         y=study.best_value,
+         color='r', linestyle='--', label=f'Best: {study.best_value:.2f}%'
+     )
+     axes[0].set_xlabel('Trial')
+     axes[0].set_ylabel('Validation Accuracy (%)')
+     axes[0].legend()
+
+     # Hyperparameter importance (if enough trials completed)
+     axes[1].set_title('Hyperparameter Importance')
+     completed_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
+
+     if len(completed_trials) >= 5:
+         importance = optuna.importance.get_param_importances(study)
+         params = list(importance.keys())
+         values = list(importance.values())
+
+         axes[1].set_xlabel('Importance')
+         axes[1].barh(params, values, color='black')
+
+     else:
+         axes[1].text(
+             0.5, 0.5,
+             'Not enough completed trials\nfor importance analysis',
+             ha='center', va='center', transform=axes[1].transAxes
+         )
+
+     plt.tight_layout()
+
+     return fig, axes
@@ -12,11 +12,28 @@ def train_model(
      criterion: nn.Module,
      optimizer: optim.Optimizer,
      epochs: int = 10,
-     print_every: int = 1
+     print_every: int = 1,
+     device: torch.device | str | None = None
  ) -> dict[str, list[float]]:
      '''Training loop for PyTorch classification model.

-     Note: Assumes data is already on the correct device.
+     Handles both pre-loaded GPU data and lazy-loading (CPU data moved per-batch).
+
+     Args:
+         model: PyTorch model to train.
+         train_loader: DataLoader for training data.
+         val_loader: DataLoader for validation data.
+         criterion: Loss function.
+         optimizer: Optimizer.
+         epochs: Number of training epochs.
+         print_every: Print progress every n epochs.
+         device: Device to move data to per-batch. If None, assumes data is
+             already on the correct device (GPU pre-loading). If specified,
+             data will be moved to this device per-batch (lazy loading).
+
+     Returns:
+         Dictionary containing training history with keys:
+         'train_loss', 'val_loss', 'train_accuracy', 'val_accuracy'.
      '''

      history = {'train_loss': [], 'val_loss': [], 'train_accuracy': [], 'val_accuracy': []}
@@ -30,6 +47,11 @@ def train_model(
          total = 0

          for images, labels in train_loader:
+
+             # Move to device if lazy loading
+             if device is not None:
+                 images = images.to(device, non_blocking=True)
+                 labels = labels.to(device, non_blocking=True)

              # Forward pass
              optimizer.zero_grad()
@@ -59,6 +81,11 @@ def train_model(
          with torch.no_grad():

              for images, labels in val_loader:
+
+                 # Move to device if lazy loading
+                 if device is not None:
+                     images = images.to(device, non_blocking=True)
+                     labels = labels.to(device, non_blocking=True)

                  outputs = model(images)
                  loss = criterion(outputs, labels)
@@ -88,6 +115,4 @@ def train_model(
                  f'val_accuracy: {val_accuracy:.2f}%'
              )

-     print('\nTraining complete.')
-
      return history
@@ -1,35 +0,0 @@
- Metadata-Version: 2.4
- Name: cifar10_tools
- Version: 0.2.0
- Summary: Tools for training neural networks on the CIFAR-10 task with PyTorch and TensorFlow
- License: GPLv3
- License-File: LICENSE
- Keywords: Python,Machine learning,Deep learning,CNNs,Computer vision,Image classification,CIFAR-10
- Author: gperdrizet
- Author-email: george@perdrizet.org
- Requires-Python: >=3.10,<3.13
- Classifier: Development Status :: 3 - Alpha
- Classifier: Intended Audience :: Developers
- Classifier: Intended Audience :: Education
- Classifier: Intended Audience :: Science/Research
- Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
- Classifier: License :: Other/Proprietary License
- Classifier: Operating System :: OS Independent
- Classifier: Programming Language :: Python :: 3
- Classifier: Programming Language :: Python :: 3.10
- Classifier: Programming Language :: Python :: 3.11
- Classifier: Programming Language :: Python :: 3.12
- Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
- Classifier: Topic :: Scientific/Engineering :: Image Recognition
- Provides-Extra: tensorflow
- Requires-Dist: numpy (>=1.24)
- Requires-Dist: torch (>=2.0)
- Requires-Dist: torchvision (>=0.15)
- Project-URL: Documentation, https://gperdrizet.github.io/CIFAR10/README.md
- Project-URL: Homepage, https://github.com/gperdrizet/CIFAR10
- Project-URL: Issues, https://github.com/gperdrizet/CIFAR10/issues
- Project-URL: PyPI, https://pypi.org/project/cifar10_tools
- Project-URL: Repository, https://github.com/gperdrizet/CIFAR10
- Description-Content-Type: text/markdown
-
- # PyTorch: CIFAR10 demonstration
@@ -1 +0,0 @@
- # PyTorch: CIFAR10 demonstration
@@ -1,27 +0,0 @@
- '''Data download function for CIFAR-10 dataset. Use to pre-download data
- during devcontainer creation'''
-
- from pathlib import Path
- from torchvision import datasets
-
- def download_cifar10_data(data_dir: str='data/pytorch/cifar10'):
-     '''Download CIFAR-10 dataset using torchvision.datasets.'''
-
-     data_dir = Path(data_dir)
-     data_dir.mkdir(parents=True, exist_ok=True)
-
-     _ = datasets.CIFAR10(
-         root=data_dir,
-         train=True,
-         download=True
-     )
-
-     _ = datasets.CIFAR10(
-         root=data_dir,
-         train=False,
-         download=True
-     )
-
- if __name__ == '__main__':
-
-     download_cifar10_data()
File without changes