cifar10-tools 0.1.0__tar.gz → 0.2.0__tar.gz

This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: cifar10_tools
- Version: 0.1.0
+ Version: 0.2.0
  Summary: Tools for training neural networks on the CIFAR-10 task with PyTorch and TensorFlow
  License: GPLv3
  License-File: LICENSE
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
  [tool.poetry]
  name = "cifar10_tools"
- version = "0.1.0"
+ version = "0.2.0"
  description = "Tools for training neural networks on the CIFAR-10 task with PyTorch and TensorFlow"
  authors = ["gperdrizet <george@perdrizet.org>"]
  readme = "README.md"
@@ -4,7 +4,7 @@ during devcontainer creation'''
  from pathlib import Path
  from torchvision import datasets
 
- def download_cifar10_data(data_dir: str='data/pytorch/CIFAR10'):
+ def download_cifar10_data(data_dir: str='data/pytorch/cifar10'):
      '''Download CIFAR-10 dataset using torchvision.datasets.'''
 
      data_dir = Path(data_dir)
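The rest of the helper's body is truncated by the diff context. For reference, a minimal sketch of how such a torchvision download step typically looks; the mkdir call and the two datasets.CIFAR10(..., download=True) calls are assumptions, not shown in the diff:

    from pathlib import Path
    from torchvision import datasets

    def download_cifar10_data(data_dir: str = 'data/pytorch/cifar10'):
        '''Download CIFAR-10 dataset using torchvision.datasets.'''
        data_dir = Path(data_dir)
        data_dir.mkdir(parents=True, exist_ok=True)  # assumed; not shown in the diff

        # torchvision skips the download when the archive is already cached under root
        datasets.CIFAR10(root=data_dir, train=True, download=True)
        datasets.CIFAR10(root=data_dir, train=False, download=True)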
@@ -0,0 +1,239 @@
+ '''Plotting functions for CIFAR-10 models.'''
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ from torch.utils.data import Dataset
+
+
+ def plot_sample_images(
+     dataset: Dataset,
+     class_names: list[str],
+     nrows: int = 2,
+     ncols: int = 5,
+     figsize: tuple[float, float] | None = None,
+     cmap: str = 'gray'
+ ) -> tuple[plt.Figure, np.ndarray]:
+     '''Plot sample images from a dataset.
+
+     Args:
+         dataset: PyTorch dataset containing (image, label) tuples.
+         class_names: List of class names for labeling.
+         nrows: Number of rows in the grid.
+         ncols: Number of columns in the grid.
+         figsize: Figure size (width, height). Defaults to (ncols*1.5, nrows*1.5).
+         cmap: Colormap for displaying images.
+
+     Returns:
+         Tuple of (figure, axes array).
+     '''
+     if figsize is None:
+         figsize = (ncols * 1.5, nrows * 1.5)
+
+     fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
+     axes = axes.flatten()
+
+     for i, ax in enumerate(axes):
+         # Get image and label from dataset
+         img, label = dataset[i]
+
+         # Unnormalize and squeeze for plotting
+         img = img * 0.5 + 0.5
+         img = img.numpy().squeeze()
+         ax.set_title(class_names[label])
+         ax.imshow(img, cmap=cmap)
+         ax.axis('off')
+
+     plt.tight_layout()
+
+     return fig, axes
+
+
+ def plot_learning_curves(
+     history: dict[str, list[float]],
+     figsize: tuple[float, float] = (10, 4)
+ ) -> tuple[plt.Figure, np.ndarray]:
+     '''Plot training and validation loss and accuracy curves.
+
+     Args:
+         history: Dictionary containing 'train_loss', 'val_loss',
+             'train_accuracy', and 'val_accuracy' lists.
+         figsize: Figure size (width, height).
+
+     Returns:
+         Tuple of (figure, axes array).
+     '''
+     fig, axes = plt.subplots(1, 2, figsize=figsize)
+
+     axes[0].set_title('Loss')
+     axes[0].plot(history['train_loss'], label='Train')
+     axes[0].plot(history['val_loss'], label='Validation')
+     axes[0].set_xlabel('Epoch')
+     axes[0].set_ylabel('Loss (cross-entropy)')
+     axes[0].legend(loc='best')
+
+     axes[1].set_title('Accuracy')
+     axes[1].plot(history['train_accuracy'], label='Train')
+     axes[1].plot(history['val_accuracy'], label='Validation')
+     axes[1].set_xlabel('Epoch')
+     axes[1].set_ylabel('Accuracy (%)')
+     axes[1].legend(loc='best')
+
+     plt.tight_layout()
+
+     return fig, axes
+
+
+ def plot_confusion_matrix(
+     true_labels: np.ndarray,
+     predictions: np.ndarray,
+     class_names: list[str],
+     figsize: tuple[float, float] = (8, 8),
+     cmap: str = 'Blues'
+ ) -> tuple[plt.Figure, plt.Axes]:
+     '''Plot a confusion matrix heatmap.
+
+     Args:
+         true_labels: Array of true class labels.
+         predictions: Array of predicted class labels.
+         class_names: List of class names for labeling.
+         figsize: Figure size (width, height).
+         cmap: Colormap for the heatmap.
+
+     Returns:
+         Tuple of (figure, axes).
+     '''
+     from sklearn.metrics import confusion_matrix
+
+     cm = confusion_matrix(true_labels, predictions)
+
+     fig, ax = plt.subplots(figsize=figsize)
+
+     ax.set_title('Confusion matrix')
+     ax.imshow(cm, cmap=cmap)
+
+     # Add labels
+     ax.set_xticks(range(len(class_names)))
+     ax.set_yticks(range(len(class_names)))
+     ax.set_xticklabels(class_names, rotation=45, ha='right')
+     ax.set_yticklabels(class_names)
+     ax.set_xlabel('Predicted label')
+     ax.set_ylabel('True label')
+
+     # Add text annotations
+     for i in range(len(class_names)):
+         for j in range(len(class_names)):
+             color = 'white' if cm[i, j] > cm.max() / 2 else 'black'
+             ax.text(j, i, str(cm[i, j]), ha='center', va='center', color=color)
+
+     plt.tight_layout()
+
+     return fig, ax
+
+
+ def plot_class_probability_distributions(
+     all_probs: np.ndarray,
+     class_names: list[str],
+     nrows: int = 2,
+     ncols: int = 5,
+     figsize: tuple[float, float] = (12, 4),
+     bins: int = 50,
+     color: str = 'black'
+ ) -> tuple[plt.Figure, np.ndarray]:
+     '''Plot predicted probability distributions for each class.
+
+     Args:
+         all_probs: Array of shape (n_samples, n_classes) with predicted probabilities.
+         class_names: List of class names for labeling.
+         nrows: Number of rows in the subplot grid.
+         ncols: Number of columns in the subplot grid.
+         figsize: Figure size (width, height).
+         bins: Number of histogram bins.
+         color: Histogram bar color.
+
+     Returns:
+         Tuple of (figure, axes array).
+     '''
+     fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
+
+     fig.suptitle('Predicted probability distributions by class', fontsize=14, y=1.02)
+     fig.supxlabel('Predicted probability', fontsize=12)
+     fig.supylabel('Count', fontsize=12)
+
+     axes = axes.flatten()
+
+     for i, (ax, class_name) in enumerate(zip(axes, class_names)):
+         # Get probabilities for this class across all samples
+         class_probs = all_probs[:, i]
+
+         # Plot histogram
+         ax.hist(class_probs, bins=bins, color=color)
+         ax.set_title(class_name)
+         ax.set_xlim(0, 1)
+
+     plt.tight_layout()
+
+     return fig, axes
+
+
+ def plot_evaluation_curves(
+     true_labels: np.ndarray,
+     all_probs: np.ndarray,
+     class_names: list[str],
+     figsize: tuple[float, float] = (12, 5)
+ ) -> tuple[plt.Figure, tuple[plt.Axes, plt.Axes]]:
+     '''Plot ROC and Precision-Recall curves for multi-class classification.
+
+     Args:
+         true_labels: Array of true class labels.
+         all_probs: Array of shape (n_samples, n_classes) with predicted probabilities.
+         class_names: List of class names for labeling.
+         figsize: Figure size (width, height).
+
+     Returns:
+         Tuple of (figure, (ax1, ax2)).
+     '''
+     from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
+     from sklearn.preprocessing import label_binarize
+
+     # Binarize true labels for one-vs-rest evaluation
+     y_test_bin = label_binarize(true_labels, classes=range(len(class_names)))
+
+     # Create figure with ROC and PR curves side by side
+     fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
+
+     # Plot ROC curves for each class, labeled with per-class AUC
+     ax1.set_title('ROC curves (one-vs-rest)')
+
+     for i, class_name in enumerate(class_names):
+         fpr, tpr, _ = roc_curve(y_test_bin[:, i], all_probs[:, i])
+         roc_auc = auc(fpr, tpr)
+         ax1.plot(fpr, tpr, label=f'{class_name} (AUC = {roc_auc:.2f})')
+
+     ax1.plot([0, 1], [0, 1], 'k--', label='Random classifier')
+     ax1.set_xlabel('False positive rate')
+     ax1.set_ylabel('True positive rate')
+     ax1.legend(loc='lower right', fontsize=12)
+     ax1.set_xlim([0, 1])
+     ax1.set_ylim([0, 1.05])
+
+     # Plot Precision-Recall curves for each class, labeled with average precision
+     ax2.set_title('Precision-recall curves (one-vs-rest)')
+
+     for i, class_name in enumerate(class_names):
+         precision, recall, _ = precision_recall_curve(y_test_bin[:, i], all_probs[:, i])
+         ap = average_precision_score(y_test_bin[:, i], all_probs[:, i])
+         ax2.plot(recall, precision, label=f'{class_name} (AP = {ap:.2f})')
+
+     # Random classifier baseline (horizontal line at class prevalence = 1/num_classes)
+     baseline = 1 / len(class_names)
+     ax2.axhline(y=baseline, color='k', linestyle='--', label='Random classifier')
+
+     ax2.set_xlabel('Recall')
+     ax2.set_ylabel('Precision')
+     ax2.legend(loc='best', fontsize=12)
+     ax2.set_xlim([0, 1])
+     ax2.set_ylim([0, 1.05])
+
+     plt.tight_layout()
+
+     return fig, (ax1, ax2)
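Taken together, the new module covers dataset inspection through final model evaluation. A minimal usage sketch, assuming the file is importable as cifar10_tools.plotting (the diff does not show the package layout) and using a stand-in linear model in place of a trained network:

    import numpy as np
    import torch
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    from cifar10_tools import plotting  # assumed import path

    class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']

    # Grayscale + (0.5, 0.5) normalization to match the img * 0.5 + 0.5
    # unnormalization and cmap='gray' default in plot_sample_images
    transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    test_set = datasets.CIFAR10('data/pytorch/cifar10', train=False,
                                download=True, transform=transform)

    plotting.plot_sample_images(test_set, class_names)

    # Stand-in for a trained network; any module mapping images to 10 logits works
    model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(32 * 32, 10))
    model.eval()

    # Collect class probabilities over the test set
    probs, labels = [], []
    with torch.no_grad():
        for x, y in DataLoader(test_set, batch_size=256):
            probs.append(torch.softmax(model(x), dim=1).numpy())
            labels.append(y.numpy())
    all_probs = np.concatenate(probs)
    true_labels = np.concatenate(labels)

    plotting.plot_confusion_matrix(true_labels, all_probs.argmax(axis=1), class_names)
    plotting.plot_class_probability_distributions(all_probs, class_names)
    plotting.plot_evaluation_curves(true_labels, all_probs, class_names)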