cifar10-tools 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,33 @@
1
'''PyTorch utilities for CIFAR-10 classification.'''

# Data loading, evaluation, and training entry points.
from cifar10_tools.pytorch.data import download_cifar10_data
from cifar10_tools.pytorch.evaluation import evaluate_model
from cifar10_tools.pytorch.training import train_model

# Visualization helpers for datasets, training runs, and Optuna studies.
from cifar10_tools.pytorch.plotting import (
    plot_sample_images,
    plot_learning_curves,
    plot_confusion_matrix,
    plot_class_probability_distributions,
    plot_evaluation_curves,
    plot_optimization_results
)

# Optuna-based hyperparameter search utilities.
from cifar10_tools.pytorch.hyperparameter_optimization import (
    create_cnn,
    train_trial,
    create_objective
)

# Public API re-exported at the subpackage level.
__all__ = [
    'download_cifar10_data',
    'evaluate_model',
    'train_model',
    'plot_sample_images',
    'plot_learning_curves',
    'plot_confusion_matrix',
    'plot_class_probability_distributions',
    'plot_evaluation_curves',
    'plot_optimization_results',
    'create_cnn',
    'train_trial',
    'create_objective'
]
@@ -0,0 +1,236 @@
1
+ '''Hyperparameter optimization utilities for CNN models using Optuna.
2
+
3
+ This module provides functions for building configurable CNN architectures
4
+ and running hyperparameter optimization with Optuna.
5
+ '''
6
+
7
+ from typing import Callable
8
+
9
+ import optuna
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.optim as optim
13
+ from torch.utils.data import DataLoader
14
+
15
+
16
def create_cnn(
    n_conv_blocks: int,
    initial_filters: int,
    fc_units_1: int,
    fc_units_2: int,
    dropout_rate: float,
    use_batch_norm: bool,
    num_classes: int = 10,
    in_channels: int = 3,
    input_size: int = 32
) -> nn.Sequential:
    '''Build a configurable VGG-style CNN as a flat ``nn.Sequential``.

    Each block is two 3x3 convolutions (optionally batch-normalized),
    each followed by ReLU, then 2x2 max pooling and dropout. The filter
    count doubles per block; spatial size halves per block. A three-layer
    fully connected classifier follows.

    Args:
        n_conv_blocks: Number of convolutional blocks (1-5)
        initial_filters: Filters in the first block (doubles each block)
        fc_units_1: Units in the first fully connected layer
        fc_units_2: Units in the second fully connected layer
        dropout_rate: Dropout probability
        use_batch_norm: Whether to insert BatchNorm2d after each conv
        num_classes: Number of output classes (default: 10 for CIFAR-10)
        in_channels: Number of input channels (default: 3 for RGB)
        input_size: Square input image size (default: 32 for CIFAR-10)

    Returns:
        nn.Sequential model
    '''

    def conv_unit(in_c: int, out_c: int) -> list:
        # One conv -> (optional BN) -> ReLU triplet.
        unit = [nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)]
        if use_batch_norm:
            unit.append(nn.BatchNorm2d(out_c))
        unit.append(nn.ReLU())
        return unit

    modules = []
    channels = in_channels
    size = input_size

    for idx in range(n_conv_blocks):
        filters = initial_filters << idx  # doubles every block

        modules += conv_unit(channels, filters)
        modules += conv_unit(filters, filters)
        modules += [nn.MaxPool2d(2, 2), nn.Dropout(dropout_rate)]

        channels = filters
        size //= 2

    # Classifier head: flatten, then three fully connected layers.
    flattened = channels * size * size
    modules += [
        nn.Flatten(),
        nn.Linear(flattened, fc_units_1),
        nn.ReLU(),
        nn.Dropout(dropout_rate),
        nn.Linear(fc_units_1, fc_units_2),
        nn.ReLU(),
        nn.Dropout(dropout_rate),
        nn.Linear(fc_units_2, num_classes),
    ]

    return nn.Sequential(*modules)
88
+
89
+
90
def train_trial(
    model: nn.Module,
    optimizer: optim.Optimizer,
    criterion: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int,
    trial: optuna.Trial,
    device: torch.device | str | None = None
) -> float:
    '''Train a model for a single Optuna trial with pruning support.

    Args:
        model: PyTorch model to train
        optimizer: Optimizer for training
        criterion: Loss function
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        n_epochs: Number of epochs to train
        trial: Optuna trial object for reporting and pruning
        device: Device to move data to per-batch. If None, assumes data is
            already on the correct device (matches train_model's lazy-loading
            behavior).

    Returns:
        Best validation accuracy achieved during training

    Raises:
        optuna.TrialPruned: If the trial is pruned by Optuna's scheduler.
    '''
    best_val_accuracy = 0.0

    for epoch in range(n_epochs):
        # Training phase
        model.train()

        for images, labels in train_loader:
            # Move to device if lazy loading
            if device is not None:
                images = images.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # Validation phase
        model.eval()
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for images, labels in val_loader:
                # Move to device if lazy loading
                if device is not None:
                    images = images.to(device, non_blocking=True)
                    labels = labels.to(device, non_blocking=True)

                outputs = model(images)
                # no_grad is active, so no .data access is needed
                _, predicted = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        # Guard against an empty validation loader (avoids ZeroDivisionError)
        val_accuracy = 100 * val_correct / val_total if val_total else 0.0
        best_val_accuracy = max(best_val_accuracy, val_accuracy)

        # Report intermediate value for pruning
        trial.report(val_accuracy, epoch)

        # Prune unpromising trials
        if trial.should_prune():
            raise optuna.TrialPruned()

    return best_val_accuracy
149
+
150
+
151
def create_objective(
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int,
    device: torch.device,
    num_classes: int = 10,
    in_channels: int = 3
) -> Callable[[optuna.Trial], float]:
    '''Build an Optuna objective for CNN hyperparameter search.

    The returned closure captures the loaders and training configuration;
    each invocation samples an architecture and optimizer, trains it, and
    returns the best validation accuracy. CUDA out-of-memory errors are
    converted into pruned trials.

    Args:
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        n_epochs: Number of epochs per trial
        device: Device to train on (cuda or cpu)
        num_classes: Number of output classes (default: 10)
        in_channels: Number of input channels (default: 3 for RGB)

    Returns:
        Objective function for optuna.Study.optimize()

    Example:
        >>> objective = create_objective(train_loader, val_loader, n_epochs=50, device=device)
        >>> study = optuna.create_study(direction='maximize')
        >>> study.optimize(objective, n_trials=100)
    '''

    def objective(trial: optuna.Trial) -> float:
        '''Sample hyperparameters, train once, return best val accuracy.'''

        # Sample the search space (names and ranges are part of the study schema).
        n_blocks = trial.suggest_int('n_conv_blocks', 1, 5)
        base_filters = trial.suggest_categorical('initial_filters', [8, 16, 32, 64, 128])
        hidden_1 = trial.suggest_categorical('fc_units_1', [128, 256, 512, 1024, 2048])
        hidden_2 = trial.suggest_categorical('fc_units_2', [32, 64, 128, 256, 512])
        p_drop = trial.suggest_float('dropout_rate', 0.2, 0.75)
        batch_norm = trial.suggest_categorical('use_batch_norm', [True, False])
        lr = trial.suggest_float('learning_rate', 1e-5, 1e-1, log=True)
        opt_choice = trial.suggest_categorical('optimizer', ['Adam', 'SGD', 'RMSprop'])

        # Instantiate the sampled architecture on the target device.
        net = create_cnn(
            n_conv_blocks=n_blocks,
            initial_filters=base_filters,
            fc_units_1=hidden_1,
            fc_units_2=hidden_2,
            dropout_rate=p_drop,
            use_batch_norm=batch_norm,
            num_classes=num_classes,
            in_channels=in_channels
        ).to(device)

        # Build the sampled optimizer; SGD additionally samples momentum.
        if opt_choice == 'Adam':
            chosen_optimizer = optim.Adam(net.parameters(), lr=lr)
        elif opt_choice == 'SGD':
            momentum = trial.suggest_float('sgd_momentum', 0.8, 0.99)
            chosen_optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
        else:  # RMSprop
            chosen_optimizer = optim.RMSprop(net.parameters(), lr=lr)

        # Train; an OOM trial is pruned rather than failing the study.
        try:
            return train_trial(
                model=net,
                optimizer=chosen_optimizer,
                criterion=nn.CrossEntropyLoss(),
                train_loader=train_loader,
                val_loader=val_loader,
                n_epochs=n_epochs,
                trial=trial
            )

        except torch.cuda.OutOfMemoryError:
            # Clear CUDA cache and skip this trial
            torch.cuda.empty_cache()
            raise optuna.TrialPruned(f'CUDA OOM with params: {trial.params}')

    return objective
@@ -10,18 +10,18 @@ def plot_sample_images(
10
10
  class_names: list[str],
11
11
  nrows: int = 2,
12
12
  ncols: int = 5,
13
- figsize: tuple[float, float] | None = None,
14
- cmap: str = 'gray'
13
+ figsize: tuple[float, float] | None = None
15
14
  ) -> tuple[plt.Figure, np.ndarray]:
16
15
  '''Plot sample images from a dataset.
17
16
 
17
+ Automatically handles both grayscale (1 channel) and RGB (3 channel) images.
18
+
18
19
  Args:
19
20
  dataset: PyTorch dataset containing (image, label) tuples.
20
21
  class_names: List of class names for labeling.
21
22
  nrows: Number of rows in the grid.
22
23
  ncols: Number of columns in the grid.
23
24
  figsize: Figure size (width, height). Defaults to (ncols*1.5, nrows*1.5).
24
- cmap: Colormap for displaying images.
25
25
 
26
26
  Returns:
27
27
  Tuple of (figure, axes array).
@@ -36,11 +36,24 @@ def plot_sample_images(
36
36
  # Get image and label from dataset
37
37
  img, label = dataset[i]
38
38
 
39
- # Unnormalize and squeeze for plotting
39
+ # Unnormalize for plotting
40
40
  img = img * 0.5 + 0.5
41
- img = img.numpy().squeeze()
41
+ img = img.numpy()
42
+
43
+ # Handle grayscale vs RGB images
44
+ if img.shape[0] == 1:
45
+
46
+ # Grayscale: squeeze channel dimension
47
+ img = img.squeeze()
48
+ ax.imshow(img, cmap='gray')
49
+
50
+ else:
51
+
52
+ # RGB: transpose from (C, H, W) to (H, W, C)
53
+ img = np.transpose(img, (1, 2, 0))
54
+ ax.imshow(img)
55
+
42
56
  ax.set_title(class_names[label])
43
- ax.imshow(img, cmap=cmap)
44
57
  ax.axis('off')
45
58
 
46
59
  plt.tight_layout()
@@ -235,4 +248,60 @@ def plot_evaluation_curves(
235
248
 
236
249
  plt.tight_layout()
237
250
 
238
- return fig, (ax1, ax2)
251
+ return fig, (ax1, ax2)
252
+
253
+
254
def plot_optimization_results(
    study,
    figsize: tuple[float, float] = (12, 4)
) -> tuple[plt.Figure, np.ndarray]:
    '''Plot Optuna optimization history and hyperparameter importance.

    Args:
        study: Optuna study object with completed trials.
        figsize: Figure size (width, height).

    Returns:
        Tuple of (figure, axes array).
    '''
    import optuna

    fig, axes = plt.subplots(1, 2, figsize=figsize)
    history_ax, importance_ax = axes

    # Left panel: per-trial objective values with a best-so-far reference line.
    history_ax.set_title('Optimization History')

    finished = [t for t in study.trials if t.value is not None]

    history_ax.plot(
        [t.number for t in finished],
        [t.value for t in finished],
        'ko-', alpha=0.6
    )
    history_ax.axhline(
        y=study.best_value,
        color='r', linestyle='--', label=f'Best: {study.best_value:.2f}%'
    )
    history_ax.set_xlabel('Trial')
    history_ax.set_ylabel('Validation Accuracy (%)')
    history_ax.legend()

    # Right panel: parameter importance, shown only with enough completed trials.
    importance_ax.set_title('Hyperparameter Importance')
    n_complete = sum(
        1 for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE
    )

    if n_complete >= 5:
        importance = optuna.importance.get_param_importances(study)

        importance_ax.set_xlabel('Importance')
        importance_ax.barh(
            list(importance.keys()), list(importance.values()), color='black'
        )

    else:
        importance_ax.text(
            0.5, 0.5,
            'Not enough completed trials\nfor importance analysis',
            ha='center', va='center', transform=importance_ax.transAxes
        )

    plt.tight_layout()

    return fig, axes
@@ -12,11 +12,28 @@ def train_model(
12
12
  criterion: nn.Module,
13
13
  optimizer: optim.Optimizer,
14
14
  epochs: int = 10,
15
- print_every: int = 1
15
+ print_every: int = 1,
16
+ device: torch.device | str | None = None
16
17
  ) -> dict[str, list[float]]:
17
18
  '''Training loop for PyTorch classification model.
18
19
 
19
- Note: Assumes data is already on the correct device.
20
+ Handles both pre-loaded GPU data and lazy-loading (CPU data moved per-batch).
21
+
22
+ Args:
23
+ model: PyTorch model to train.
24
+ train_loader: DataLoader for training data.
25
+ val_loader: DataLoader for validation data.
26
+ criterion: Loss function.
27
+ optimizer: Optimizer.
28
+ epochs: Number of training epochs.
29
+ print_every: Print progress every n epochs.
30
+ device: Device to move data to per-batch. If None, assumes data is
31
+ already on the correct device (GPU pre-loading). If specified,
32
+ data will be moved to this device per-batch (lazy loading).
33
+
34
+ Returns:
35
+ Dictionary containing training history with keys:
36
+ 'train_loss', 'val_loss', 'train_accuracy', 'val_accuracy'.
20
37
  '''
21
38
 
22
39
  history = {'train_loss': [], 'val_loss': [], 'train_accuracy': [], 'val_accuracy': []}
@@ -30,6 +47,11 @@ def train_model(
30
47
  total = 0
31
48
 
32
49
  for images, labels in train_loader:
50
+
51
+ # Move to device if lazy loading
52
+ if device is not None:
53
+ images = images.to(device, non_blocking=True)
54
+ labels = labels.to(device, non_blocking=True)
33
55
 
34
56
  # Forward pass
35
57
  optimizer.zero_grad()
@@ -59,6 +81,11 @@ def train_model(
59
81
  with torch.no_grad():
60
82
 
61
83
  for images, labels in val_loader:
84
+
85
+ # Move to device if lazy loading
86
+ if device is not None:
87
+ images = images.to(device, non_blocking=True)
88
+ labels = labels.to(device, non_blocking=True)
62
89
 
63
90
  outputs = model(images)
64
91
  loss = criterion(outputs, labels)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cifar10_tools
3
- Version: 0.2.0
3
+ Version: 0.3.0
4
4
  Summary: Tools for training neural networks on the CIFAR-10 task with PyTorch and TensorFlow
5
5
  License: GPLv3
6
6
  License-File: LICENSE
@@ -0,0 +1,12 @@
1
+ cifar10_tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ cifar10_tools/pytorch/__init__.py,sha256=4er-aMGK-MZTlkH3Owz3x-Pz_Gl_NjplKwOBYdBA1p0,909
3
+ cifar10_tools/pytorch/data.py,sha256=09zodpjto0xLq95tDAyq57CFh6MSYRuUBPcMmQcyKZM,626
4
+ cifar10_tools/pytorch/evaluation.py,sha256=i4tRYOqWATVqQVkWT_fATWRbzo9ziX2DDkXKPaiQlFE,923
5
+ cifar10_tools/pytorch/hyperparameter_optimization.py,sha256=92MwDp6CarFp6O-tkJqeVqDyn0Az15gu3pluAvnO2mw,8056
6
+ cifar10_tools/pytorch/plotting.py,sha256=9kRDt9ZEX0uOUlt-9wzJHrx4WELuFYMeeQiJrmwyXNs,9550
7
+ cifar10_tools/pytorch/training.py,sha256=KNaH-Q9u61o3DIcTfBhjnOvOD7yExZeXwBm6qvMGL9I,3859
8
+ cifar10_tools/tensorflow/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
+ cifar10_tools-0.3.0.dist-info/METADATA,sha256=Ll6YMa77t9ubJLaiFF8BsMmDuj_pzTLejL6Wlje2Qwo,1580
10
+ cifar10_tools-0.3.0.dist-info/WHEEL,sha256=kJCRJT_g0adfAJzTx2GUMmS80rTJIVHRCfG0DQgLq3o,88
11
+ cifar10_tools-0.3.0.dist-info/licenses/LICENSE,sha256=wtHfRwmCF5-_XUmYwrBKwJkGipvHVmh7GXJOKKeOe2U,1073
12
+ cifar10_tools-0.3.0.dist-info/RECORD,,
@@ -1,11 +0,0 @@
1
- cifar10_tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- cifar10_tools/pytorch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- cifar10_tools/pytorch/data.py,sha256=09zodpjto0xLq95tDAyq57CFh6MSYRuUBPcMmQcyKZM,626
4
- cifar10_tools/pytorch/evaluation.py,sha256=i4tRYOqWATVqQVkWT_fATWRbzo9ziX2DDkXKPaiQlFE,923
5
- cifar10_tools/pytorch/plotting.py,sha256=B1ifJxbSEDpInnVk9c3o1fjVx534TPPKTWM5iusyzrE,7494
6
- cifar10_tools/pytorch/training.py,sha256=Sg6NlBT_DTyLzf-Ls3bYI8-8AwGFJblRj0MDnUmGP3Q,2642
7
- cifar10_tools/tensorflow/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
- cifar10_tools-0.2.0.dist-info/METADATA,sha256=3s6_5lP8rAnEu5F9r5YKU-EqUi9UO3mNUFK1ikVgUfc,1580
9
- cifar10_tools-0.2.0.dist-info/WHEEL,sha256=kJCRJT_g0adfAJzTx2GUMmS80rTJIVHRCfG0DQgLq3o,88
10
- cifar10_tools-0.2.0.dist-info/licenses/LICENSE,sha256=wtHfRwmCF5-_XUmYwrBKwJkGipvHVmh7GXJOKKeOe2U,1073
11
- cifar10_tools-0.2.0.dist-info/RECORD,,