cifar10-tools 0.1.0-py3-none-any.whl → 0.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cifar10_tools/pytorch/__init__.py
@@ -0,0 +1,33 @@
+ '''PyTorch utilities for CIFAR-10 classification.'''
+
+ from cifar10_tools.pytorch.data import download_cifar10_data
+ from cifar10_tools.pytorch.evaluation import evaluate_model
+ from cifar10_tools.pytorch.training import train_model
+ from cifar10_tools.pytorch.plotting import (
+     plot_sample_images,
+     plot_learning_curves,
+     plot_confusion_matrix,
+     plot_class_probability_distributions,
+     plot_evaluation_curves,
+     plot_optimization_results
+ )
+ from cifar10_tools.pytorch.hyperparameter_optimization import (
+     create_cnn,
+     train_trial,
+     create_objective
+ )
+
+ __all__ = [
+     'download_cifar10_data',
+     'evaluate_model',
+     'train_model',
+     'plot_sample_images',
+     'plot_learning_curves',
+     'plot_confusion_matrix',
+     'plot_class_probability_distributions',
+     'plot_evaluation_curves',
+     'plot_optimization_results',
+     'create_cnn',
+     'train_trial',
+     'create_objective'
+ ]
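
The practical effect of the new cifar10_tools/pytorch/__init__.py is a flat public API: everything can now be imported straight from the subpackage rather than from its individual modules. A minimal sketch:

    from cifar10_tools.pytorch import (
        download_cifar10_data,
        train_model,
        evaluate_model,
        plot_learning_curves,
        create_objective,
    )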
cifar10_tools/pytorch/data.py
@@ -4,7 +4,7 @@ during devcontainer creation'''
  from pathlib import Path
  from torchvision import datasets

- def download_cifar10_data(data_dir: str='data/pytorch/CIFAR10'):
+ def download_cifar10_data(data_dir: str='data/pytorch/cifar10'):
      '''Download CIFAR-10 dataset using torchvision.datasets.'''

      data_dir = Path(data_dir)
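
The only functional change in data.py is the default download directory, which moves from data/pytorch/CIFAR10 to the lowercase data/pytorch/cifar10. Code that relied on the old default should pass the path explicitly or migrate the directory; a sketch:

    from cifar10_tools.pytorch import download_cifar10_data

    # 0.3.0 default: data/pytorch/cifar10
    download_cifar10_data()

    # Pin the pre-0.3.0 location explicitly if other tooling expects it
    download_cifar10_data(data_dir='data/pytorch/CIFAR10')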
cifar10_tools/pytorch/hyperparameter_optimization.py
@@ -0,0 +1,236 @@
+ '''Hyperparameter optimization utilities for CNN models using Optuna.
+
+ This module provides functions for building configurable CNN architectures
+ and running hyperparameter optimization with Optuna.
+ '''
+
+ from typing import Callable
+
+ import optuna
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import DataLoader
+
+
+ def create_cnn(
+     n_conv_blocks: int,
+     initial_filters: int,
+     fc_units_1: int,
+     fc_units_2: int,
+     dropout_rate: float,
+     use_batch_norm: bool,
+     num_classes: int = 10,
+     in_channels: int = 3,
+     input_size: int = 32
+ ) -> nn.Sequential:
+     '''Create a CNN with configurable architecture.
+
+     Args:
+         n_conv_blocks: Number of convolutional blocks (1-5)
+         initial_filters: Number of filters in first conv layer (doubles each block)
+         fc_units_1: Number of units in first fully connected layer
+         fc_units_2: Number of units in second fully connected layer
+         dropout_rate: Dropout probability
+         use_batch_norm: Whether to use batch normalization
+         num_classes: Number of output classes (default: 10 for CIFAR-10)
+         in_channels: Number of input channels (default: 3 for RGB)
+         input_size: Input image size (default: 32 for CIFAR-10)
+
+     Returns:
+         nn.Sequential model
+     '''
+     layers = []
+     current_channels = in_channels
+     current_size = input_size
+
+     for block_idx in range(n_conv_blocks):
+         out_channels = initial_filters * (2 ** block_idx)
+
+         # First conv in block
+         layers.append(nn.Conv2d(current_channels, out_channels, kernel_size=3, padding=1))
+
+         if use_batch_norm:
+             layers.append(nn.BatchNorm2d(out_channels))
+
+         layers.append(nn.ReLU())
+
+         # Second conv in block
+         layers.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
+
+         if use_batch_norm:
+             layers.append(nn.BatchNorm2d(out_channels))
+
+         layers.append(nn.ReLU())
+
+         # Pooling and dropout
+         layers.append(nn.MaxPool2d(2, 2))
+         layers.append(nn.Dropout(dropout_rate))
+
+         current_channels = out_channels
+         current_size //= 2
+
+     # Calculate flattened size
+     final_channels = initial_filters * (2 ** (n_conv_blocks - 1))
+     flattened_size = final_channels * current_size * current_size
+
+     # Classifier (3 fully connected layers)
+     layers.append(nn.Flatten())
+     layers.append(nn.Linear(flattened_size, fc_units_1))
+     layers.append(nn.ReLU())
+     layers.append(nn.Dropout(dropout_rate))
+     layers.append(nn.Linear(fc_units_1, fc_units_2))
+     layers.append(nn.ReLU())
+     layers.append(nn.Dropout(dropout_rate))
+     layers.append(nn.Linear(fc_units_2, num_classes))
+
+     return nn.Sequential(*layers)
+
+
+ def train_trial(
+     model: nn.Module,
+     optimizer: optim.Optimizer,
+     criterion: nn.Module,
+     train_loader: DataLoader,
+     val_loader: DataLoader,
+     n_epochs: int,
+     trial: optuna.Trial
+ ) -> float:
+     '''Train a model for a single Optuna trial with pruning support.
+
+     Args:
+         model: PyTorch model to train
+         optimizer: Optimizer for training
+         criterion: Loss function
+         train_loader: DataLoader for training data
+         val_loader: DataLoader for validation data
+         n_epochs: Number of epochs to train
+         trial: Optuna trial object for reporting and pruning
+
+     Returns:
+         Best validation accuracy achieved during training
+     '''
+     best_val_accuracy = 0.0
+
+     for epoch in range(n_epochs):
+         # Training phase
+         model.train()
+
+         for images, labels in train_loader:
+             optimizer.zero_grad()
+             outputs = model(images)
+             loss = criterion(outputs, labels)
+             loss.backward()
+             optimizer.step()
+
+         # Validation phase
+         model.eval()
+         val_correct = 0
+         val_total = 0
+
+         with torch.no_grad():
+             for images, labels in val_loader:
+                 outputs = model(images)
+                 _, predicted = torch.max(outputs.data, 1)
+                 val_total += labels.size(0)
+                 val_correct += (predicted == labels).sum().item()
+
+         val_accuracy = 100 * val_correct / val_total
+         best_val_accuracy = max(best_val_accuracy, val_accuracy)
+
+         # Report intermediate value for pruning
+         trial.report(val_accuracy, epoch)
+
+         # Prune unpromising trials
+         if trial.should_prune():
+             raise optuna.TrialPruned()
+
+     return best_val_accuracy
+
+
+ def create_objective(
+     train_loader: DataLoader,
+     val_loader: DataLoader,
+     n_epochs: int,
+     device: torch.device,
+     num_classes: int = 10,
+     in_channels: int = 3
+ ) -> Callable[[optuna.Trial], float]:
+     '''Create an Optuna objective function for CNN hyperparameter optimization.
+
+     This factory function creates a closure that captures the data loaders and
+     training configuration, returning an objective function suitable for Optuna.
+
+     Args:
+         train_loader: DataLoader for training data
+         val_loader: DataLoader for validation data
+         n_epochs: Number of epochs per trial
+         device: Device to train on (cuda or cpu)
+         num_classes: Number of output classes (default: 10)
+         in_channels: Number of input channels (default: 3 for RGB)
+
+     Returns:
+         Objective function for optuna.Study.optimize()
+
+     Example:
+         >>> objective = create_objective(train_loader, val_loader, n_epochs=50, device=device)
+         >>> study = optuna.create_study(direction='maximize')
+         >>> study.optimize(objective, n_trials=100)
+     '''
+
+     def objective(trial: optuna.Trial) -> float:
+         '''Optuna objective function for CNN hyperparameter optimization.'''
+
+         # Suggest hyperparameters
+         n_conv_blocks = trial.suggest_int('n_conv_blocks', 1, 5)
+         initial_filters = trial.suggest_categorical('initial_filters', [8, 16, 32, 64, 128])
+         fc_units_1 = trial.suggest_categorical('fc_units_1', [128, 256, 512, 1024, 2048])
+         fc_units_2 = trial.suggest_categorical('fc_units_2', [32, 64, 128, 256, 512])
+         dropout_rate = trial.suggest_float('dropout_rate', 0.2, 0.75)
+         use_batch_norm = trial.suggest_categorical('use_batch_norm', [True, False])
+         learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-1, log=True)
+         optimizer_name = trial.suggest_categorical('optimizer', ['Adam', 'SGD', 'RMSprop'])
+
+         # Create model
+         model = create_cnn(
+             n_conv_blocks=n_conv_blocks,
+             initial_filters=initial_filters,
+             fc_units_1=fc_units_1,
+             fc_units_2=fc_units_2,
+             dropout_rate=dropout_rate,
+             use_batch_norm=use_batch_norm,
+             num_classes=num_classes,
+             in_channels=in_channels
+         ).to(device)
+
+         # Define optimizer
+         if optimizer_name == 'Adam':
+             optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+
+         elif optimizer_name == 'SGD':
+             momentum = trial.suggest_float('sgd_momentum', 0.8, 0.99)
+             optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
+
+         else: # RMSprop
+             optimizer = optim.RMSprop(model.parameters(), lr=learning_rate)
+
+         criterion = nn.CrossEntropyLoss()
+
+         # Train model and return best validation accuracy
+         try:
+             return train_trial(
+                 model=model,
+                 optimizer=optimizer,
+                 criterion=criterion,
+                 train_loader=train_loader,
+                 val_loader=val_loader,
+                 n_epochs=n_epochs,
+                 trial=trial
+             )
+
+         except torch.cuda.OutOfMemoryError:
+             # Clear CUDA cache and skip this trial
+             torch.cuda.empty_cache()
+             raise optuna.TrialPruned(f'CUDA OOM with params: {trial.params}')
+
+     return objective
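
Note that train_trial never moves batches off the DataLoader, so the loaders passed to create_objective must already yield tensors on the same device the model is sent to (the GPU pre-loading pattern this package uses elsewhere). Below is a hedged end-to-end sketch of a study run, kept on CPU for exactly that reason; the data pipeline and pruner choice are illustrative assumptions, and only create_objective and its documented usage come from this module:

    import optuna
    import torch
    from torch.utils.data import DataLoader, random_split
    from torchvision import datasets, transforms

    from cifar10_tools.pytorch import create_objective

    device = torch.device('cpu')  # the loaders below yield CPU tensors

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    full_train = datasets.CIFAR10('data/pytorch/cifar10', train=True,
                                  download=True, transform=transform)
    train_set, val_set = random_split(full_train, [45000, 5000])
    train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=256)

    objective = create_objective(train_loader, val_loader, n_epochs=10, device=device)

    # MedianPruner pairs with the per-epoch trial.report() calls in train_trial
    study = optuna.create_study(direction='maximize',
                                pruner=optuna.pruners.MedianPruner())
    study.optimize(objective, n_trials=50)
    print(study.best_params, study.best_value)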
cifar10_tools/pytorch/plotting.py
@@ -0,0 +1,307 @@
+ '''Plotting functions for CIFAR-10 models.'''
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ from torch.utils.data import Dataset
+
+
+ def plot_sample_images(
+     dataset: Dataset,
+     class_names: list[str],
+     nrows: int = 2,
+     ncols: int = 5,
+     figsize: tuple[float, float] | None = None
+ ) -> tuple[plt.Figure, np.ndarray]:
+     '''Plot sample images from a dataset.
+
+     Automatically handles both grayscale (1 channel) and RGB (3 channel) images.
+
+     Args:
+         dataset: PyTorch dataset containing (image, label) tuples.
+         class_names: List of class names for labeling.
+         nrows: Number of rows in the grid.
+         ncols: Number of columns in the grid.
+         figsize: Figure size (width, height). Defaults to (ncols*1.5, nrows*1.5).
+
+     Returns:
+         Tuple of (figure, axes array).
+     '''
+     if figsize is None:
+         figsize = (ncols * 1.5, nrows * 1.5)
+
+     fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
+     axes = axes.flatten()
+
+     for i, ax in enumerate(axes):
+         # Get image and label from dataset
+         img, label = dataset[i]
+
+         # Unnormalize for plotting
+         img = img * 0.5 + 0.5
+         img = img.numpy()
+
+         # Handle grayscale vs RGB images
+         if img.shape[0] == 1:
+
+             # Grayscale: squeeze channel dimension
+             img = img.squeeze()
+             ax.imshow(img, cmap='gray')
+
+         else:
+
+             # RGB: transpose from (C, H, W) to (H, W, C)
+             img = np.transpose(img, (1, 2, 0))
+             ax.imshow(img)
+
+         ax.set_title(class_names[label])
+         ax.axis('off')
+
+     plt.tight_layout()
+
+     return fig, axes
+
+
+ def plot_learning_curves(
+     history: dict[str, list[float]],
+     figsize: tuple[float, float] = (10, 4)
+ ) -> tuple[plt.Figure, np.ndarray]:
+     '''Plot training and validation loss and accuracy curves.
+
+     Args:
+         history: Dictionary containing 'train_loss', 'val_loss',
+             'train_accuracy', and 'val_accuracy' lists.
+         figsize: Figure size (width, height).
+
+     Returns:
+         Tuple of (figure, axes array).
+     '''
+     fig, axes = plt.subplots(1, 2, figsize=figsize)
+
+     axes[0].set_title('Loss')
+     axes[0].plot(history['train_loss'], label='Train')
+     axes[0].plot(history['val_loss'], label='Validation')
+     axes[0].set_xlabel('Epoch')
+     axes[0].set_ylabel('Loss (cross-entropy)')
+     axes[0].legend(loc='best')
+
+     axes[1].set_title('Accuracy')
+     axes[1].plot(history['train_accuracy'], label='Train')
+     axes[1].plot(history['val_accuracy'], label='Validation')
+     axes[1].set_xlabel('Epoch')
+     axes[1].set_ylabel('Accuracy (%)')
+     axes[1].legend(loc='best')
+
+     plt.tight_layout()
+
+     return fig, axes
+
+
+ def plot_confusion_matrix(
+     true_labels: np.ndarray,
+     predictions: np.ndarray,
+     class_names: list[str],
+     figsize: tuple[float, float] = (8, 8),
+     cmap: str = 'Blues'
+ ) -> tuple[plt.Figure, plt.Axes]:
+     '''Plot a confusion matrix heatmap.
+
+     Args:
+         true_labels: Array of true class labels.
+         predictions: Array of predicted class labels.
+         class_names: List of class names for labeling.
+         figsize: Figure size (width, height).
+         cmap: Colormap for the heatmap.
+
+     Returns:
+         Tuple of (figure, axes).
+     '''
+     from sklearn.metrics import confusion_matrix
+
+     cm = confusion_matrix(true_labels, predictions)
+
+     fig, ax = plt.subplots(figsize=figsize)
+
+     ax.set_title('Confusion matrix')
+     im = ax.imshow(cm, cmap=cmap)
+
+     # Add labels
+     ax.set_xticks(range(len(class_names)))
+     ax.set_yticks(range(len(class_names)))
+     ax.set_xticklabels(class_names, rotation=45, ha='right')
+     ax.set_yticklabels(class_names)
+     ax.set_xlabel('Predicted label')
+     ax.set_ylabel('True label')
+
+     # Add text annotations
+     for i in range(len(class_names)):
+         for j in range(len(class_names)):
+             color = 'white' if cm[i, j] > cm.max() / 2 else 'black'
+             ax.text(j, i, str(cm[i, j]), ha='center', va='center', color=color)
+
+     plt.tight_layout()
+
+     return fig, ax
+
+
+ def plot_class_probability_distributions(
+     all_probs: np.ndarray,
+     class_names: list[str],
+     nrows: int = 2,
+     ncols: int = 5,
+     figsize: tuple[float, float] = (12, 4),
+     bins: int = 50,
+     color: str = 'black'
+ ) -> tuple[plt.Figure, np.ndarray]:
+     '''Plot predicted probability distributions for each class.
+
+     Args:
+         all_probs: Array of shape (n_samples, n_classes) with predicted probabilities.
+         class_names: List of class names for labeling.
+         nrows: Number of rows in the subplot grid.
+         ncols: Number of columns in the subplot grid.
+         figsize: Figure size (width, height).
+         bins: Number of histogram bins.
+         color: Histogram bar color.
+
+     Returns:
+         Tuple of (figure, axes array).
+     '''
+     fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
+
+     fig.suptitle('Predicted probability distributions by class', fontsize=14, y=1.02)
+     fig.supxlabel('Predicted probability', fontsize=12)
+     fig.supylabel('Count', fontsize=12)
+
+     axes = axes.flatten()
+
+     for i, (ax, class_name) in enumerate(zip(axes, class_names)):
+         # Get probabilities for this class across all samples
+         class_probs = all_probs[:, i]
+
+         # Plot histogram
+         ax.hist(class_probs, bins=bins, color=color)
+         ax.set_title(class_name)
+         ax.set_xlim(0, 1)
+
+     plt.tight_layout()
+
+     return fig, axes
+
+
+ def plot_evaluation_curves(
+     true_labels: np.ndarray,
+     all_probs: np.ndarray,
+     class_names: list[str],
+     figsize: tuple[float, float] = (12, 5)
+ ) -> tuple[plt.Figure, tuple[plt.Axes, plt.Axes]]:
+     '''Plot ROC and Precision-Recall curves for multi-class classification.
+
+     Args:
+         true_labels: Array of true class labels.
+         all_probs: Array of shape (n_samples, n_classes) with predicted probabilities.
+         class_names: List of class names for labeling.
+         figsize: Figure size (width, height).
+
+     Returns:
+         Tuple of (figure, (ax1, ax2)).
+     '''
+     from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
+     from sklearn.preprocessing import label_binarize
+
+     # Binarize true labels for one-vs-rest evaluation
+     y_test_bin = label_binarize(true_labels, classes=range(len(class_names)))
+
+     # Create figure with ROC and PR curves side by side
+     fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
+
+     # Plot ROC curves for each class
+     ax1.set_title('ROC curves (one-vs-rest)')
+
+     for i, class_name in enumerate(class_names):
+         fpr, tpr, _ = roc_curve(y_test_bin[:, i], all_probs[:, i])
+         roc_auc = auc(fpr, tpr)
+         ax1.plot(fpr, tpr, label=f'{class_name} (AUC = {roc_auc:.2f})')
+
+     ax1.plot([0, 1], [0, 1], 'k--', label='Random classifier')
+     ax1.set_xlabel('False positive rate')
+     ax1.set_ylabel('True positive rate')
+     ax1.legend(loc='lower right', fontsize=12)
+     ax1.set_xlim([0, 1])
+     ax1.set_ylim([0, 1.05])
+
+     # Plot Precision-Recall curves for each class
+     ax2.set_title('Precision-recall curves (one-vs-rest)')
+
+     for i, class_name in enumerate(class_names):
+         precision, recall, _ = precision_recall_curve(y_test_bin[:, i], all_probs[:, i])
+         ap = average_precision_score(y_test_bin[:, i], all_probs[:, i])
+         ax2.plot(recall, precision, label=f'{class_name} (AP = {ap:.2f})')
+
+     # Random classifier baseline (horizontal line at class prevalence = 1/num_classes)
+     baseline = 1 / len(class_names)
+     ax2.axhline(y=baseline, color='k', linestyle='--')
+
+     ax2.set_xlabel('Recall')
+     ax2.set_ylabel('Precision')
+     ax2.set_xlim([0, 1])
+     ax2.set_ylim([0, 1.05])
+
+     plt.tight_layout()
+
+     return fig, (ax1, ax2)
+
+
+ def plot_optimization_results(
+     study,
+     figsize: tuple[float, float] = (12, 4)
+ ) -> tuple[plt.Figure, np.ndarray]:
+     '''Plot Optuna optimization history and hyperparameter importance.
+
+     Args:
+         study: Optuna study object with completed trials.
+         figsize: Figure size (width, height).
+
+     Returns:
+         Tuple of (figure, axes array).
+     '''
+     import optuna
+
+     fig, axes = plt.subplots(1, 2, figsize=figsize)
+
+     # Optimization history
+     axes[0].set_title('Optimization History')
+
+     trial_numbers = [t.number for t in study.trials if t.value is not None]
+     trial_values = [t.value for t in study.trials if t.value is not None]
+
+     axes[0].plot(trial_numbers, trial_values, 'ko-', alpha=0.6)
+     axes[0].axhline(
+         y=study.best_value,
+         color='r', linestyle='--', label=f'Best: {study.best_value:.2f}%'
+     )
+     axes[0].set_xlabel('Trial')
+     axes[0].set_ylabel('Validation Accuracy (%)')
+     axes[0].legend()
+
+     # Hyperparameter importance (if enough trials completed)
+     axes[1].set_title('Hyperparameter Importance')
+     completed_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
+
+     if len(completed_trials) >= 5:
+         importance = optuna.importance.get_param_importances(study)
+         params = list(importance.keys())
+         values = list(importance.values())
+
+         axes[1].set_xlabel('Importance')
+         axes[1].barh(params, values, color='black')
+
+     else:
+         axes[1].text(
+             0.5, 0.5,
+             'Not enough completed trials\nfor importance analysis',
+             ha='center', va='center', transform=axes[1].transAxes
+         )
+
+     plt.tight_layout()
+
+     return fig, axes
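
As a usage sketch, plot_learning_curves consumes exactly the history dictionary that train_model returns, and plot_confusion_matrix takes plain label arrays from any evaluation loop. The history values and label arrays below are illustrative stand-ins, not real results:

    import numpy as np
    from cifar10_tools.pytorch import plot_learning_curves, plot_confusion_matrix

    class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']

    # Shape of the dict returned by train_model
    history = {
        'train_loss': [2.1, 1.6, 1.3],
        'val_loss': [2.0, 1.7, 1.5],
        'train_accuracy': [25.0, 41.0, 52.0],
        'val_accuracy': [27.0, 39.0, 47.0],
    }
    fig, axes = plot_learning_curves(history)

    # In practice these come from running the model over a test loader
    true_labels = np.random.randint(0, 10, size=1000)
    predictions = np.random.randint(0, 10, size=1000)
    fig, ax = plot_confusion_matrix(true_labels, predictions, class_names)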
cifar10_tools/pytorch/training.py
@@ -12,11 +12,28 @@ def train_model(
      criterion: nn.Module,
      optimizer: optim.Optimizer,
      epochs: int = 10,
-     print_every: int = 1
+     print_every: int = 1,
+     device: torch.device | str | None = None
  ) -> dict[str, list[float]]:
      '''Training loop for PyTorch classification model.

-     Note: Assumes data is already on the correct device.
+     Handles both pre-loaded GPU data and lazy-loading (CPU data moved per-batch).
+
+     Args:
+         model: PyTorch model to train.
+         train_loader: DataLoader for training data.
+         val_loader: DataLoader for validation data.
+         criterion: Loss function.
+         optimizer: Optimizer.
+         epochs: Number of training epochs.
+         print_every: Print progress every n epochs.
+         device: Device to move data to per-batch. If None, assumes data is
+             already on the correct device (GPU pre-loading). If specified,
+             data will be moved to this device per-batch (lazy loading).
+
+     Returns:
+         Dictionary containing training history with keys:
+         'train_loss', 'val_loss', 'train_accuracy', 'val_accuracy'.
      '''

      history = {'train_loss': [], 'val_loss': [], 'train_accuracy': [], 'val_accuracy': []}
@@ -30,6 +47,11 @@ def train_model(
          total = 0

          for images, labels in train_loader:
+
+             # Move to device if lazy loading
+             if device is not None:
+                 images = images.to(device, non_blocking=True)
+                 labels = labels.to(device, non_blocking=True)

              # Forward pass
              optimizer.zero_grad()
@@ -59,6 +81,11 @@ def train_model(
          with torch.no_grad():

              for images, labels in val_loader:
+
+                 # Move to device if lazy loading
+                 if device is not None:
+                     images = images.to(device, non_blocking=True)
+                     labels = labels.to(device, non_blocking=True)

                  outputs = model(images)
                  loss = criterion(outputs, labels)
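
Together these three hunks make the loop work under either data strategy: pass device to move batches lazily inside the loop, or leave it as None when the tensors were pre-loaded onto the target device. A self-contained sketch against the new signature (the stand-in model and random data are illustrative only):

    import torch
    from torch import nn, optim
    from torch.utils.data import DataLoader, TensorDataset

    from cifar10_tools.pytorch import train_model

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Tiny stand-in model and random CIFAR-10-shaped data, purely to exercise the API
    model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).to(device)
    images = torch.randn(256, 3, 32, 32)
    labels = torch.randint(0, 10, (256,))
    train_loader = DataLoader(TensorDataset(images, labels), batch_size=64, shuffle=True)
    val_loader = DataLoader(TensorDataset(images, labels), batch_size=64)

    history = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=nn.CrossEntropyLoss(),
        optimizer=optim.Adam(model.parameters(), lr=1e-3),
        epochs=2,
        device=device,  # lazy loading: each batch is moved inside the loop
    )
    # device=None keeps the 0.1.0 behavior: data must already sit on the model's device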
cifar10_tools-0.3.0.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: cifar10_tools
- Version: 0.1.0
+ Version: 0.3.0
  Summary: Tools for training neural networks on the CIFAR-10 task with PyTorch and TensorFlow
  License: GPLv3
  License-File: LICENSE
cifar10_tools-0.3.0.dist-info/RECORD
@@ -0,0 +1,12 @@
+ cifar10_tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ cifar10_tools/pytorch/__init__.py,sha256=4er-aMGK-MZTlkH3Owz3x-Pz_Gl_NjplKwOBYdBA1p0,909
+ cifar10_tools/pytorch/data.py,sha256=09zodpjto0xLq95tDAyq57CFh6MSYRuUBPcMmQcyKZM,626
+ cifar10_tools/pytorch/evaluation.py,sha256=i4tRYOqWATVqQVkWT_fATWRbzo9ziX2DDkXKPaiQlFE,923
+ cifar10_tools/pytorch/hyperparameter_optimization.py,sha256=92MwDp6CarFp6O-tkJqeVqDyn0Az15gu3pluAvnO2mw,8056
+ cifar10_tools/pytorch/plotting.py,sha256=9kRDt9ZEX0uOUlt-9wzJHrx4WELuFYMeeQiJrmwyXNs,9550
+ cifar10_tools/pytorch/training.py,sha256=KNaH-Q9u61o3DIcTfBhjnOvOD7yExZeXwBm6qvMGL9I,3859
+ cifar10_tools/tensorflow/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ cifar10_tools-0.3.0.dist-info/METADATA,sha256=Ll6YMa77t9ubJLaiFF8BsMmDuj_pzTLejL6Wlje2Qwo,1580
+ cifar10_tools-0.3.0.dist-info/WHEEL,sha256=kJCRJT_g0adfAJzTx2GUMmS80rTJIVHRCfG0DQgLq3o,88
+ cifar10_tools-0.3.0.dist-info/licenses/LICENSE,sha256=wtHfRwmCF5-_XUmYwrBKwJkGipvHVmh7GXJOKKeOe2U,1073
+ cifar10_tools-0.3.0.dist-info/RECORD,,
cifar10_tools-0.1.0.dist-info/RECORD
@@ -1,10 +0,0 @@
- cifar10_tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- cifar10_tools/pytorch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- cifar10_tools/pytorch/data.py,sha256=zEDdRbcCHehDg5mdOGDopKT-uCRTjF27Q_UYTAPVEhQ,626
- cifar10_tools/pytorch/evaluation.py,sha256=i4tRYOqWATVqQVkWT_fATWRbzo9ziX2DDkXKPaiQlFE,923
- cifar10_tools/pytorch/training.py,sha256=Sg6NlBT_DTyLzf-Ls3bYI8-8AwGFJblRj0MDnUmGP3Q,2642
- cifar10_tools/tensorflow/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- cifar10_tools-0.1.0.dist-info/METADATA,sha256=3wdozzaT9e9M6Tf5c7EbiJ8XVXewrJiXTHxGQxhMJ0Q,1580
- cifar10_tools-0.1.0.dist-info/WHEEL,sha256=kJCRJT_g0adfAJzTx2GUMmS80rTJIVHRCfG0DQgLq3o,88
- cifar10_tools-0.1.0.dist-info/licenses/LICENSE,sha256=wtHfRwmCF5-_XUmYwrBKwJkGipvHVmh7GXJOKKeOe2U,1073
- cifar10_tools-0.1.0.dist-info/RECORD,,