cifar10-tools 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
- cifar10_tools/pytorch/data.py +1 -1
- cifar10_tools/pytorch/plotting.py +238 -0
- {cifar10_tools-0.1.0.dist-info → cifar10_tools-0.2.0.dist-info}/METADATA +1 -1
- {cifar10_tools-0.1.0.dist-info → cifar10_tools-0.2.0.dist-info}/RECORD +6 -5
- {cifar10_tools-0.1.0.dist-info → cifar10_tools-0.2.0.dist-info}/WHEEL +0 -0
- {cifar10_tools-0.1.0.dist-info → cifar10_tools-0.2.0.dist-info}/licenses/LICENSE +0 -0
cifar10_tools/pytorch/data.py
CHANGED
@@ -4,7 +4,7 @@ during devcontainer creation'''
 from pathlib import Path
 from torchvision import datasets
 
-def download_cifar10_data(data_dir: str='data/pytorch/
+def download_cifar10_data(data_dir: str='data/pytorch/cifar10'):
     '''Download CIFAR-10 dataset using torchvision.datasets.'''
 
     data_dir = Path(data_dir)
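For orientation, a minimal usage sketch of the changed function; the import path follows the module layout listed in the RECORD below, while the call sites themselves are assumptions about intended use, not part of the package:

    # Hypothetical usage of the 0.2.0 default path (illustration only)
    from cifar10_tools.pytorch.data import download_cifar10_data

    download_cifar10_data()                            # downloads into 'data/pytorch/cifar10'
    download_cifar10_data(data_dir='datasets/cifar')   # or any explicit directory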
cifar10_tools/pytorch/plotting.py
ADDED
@@ -0,0 +1,238 @@
+'''Plotting functions for CIFAR-10 models.'''
+
+import matplotlib.pyplot as plt
+import numpy as np
+from torch.utils.data import Dataset
+
+
+def plot_sample_images(
+    dataset: Dataset,
+    class_names: list[str],
+    nrows: int = 2,
+    ncols: int = 5,
+    figsize: tuple[float, float] | None = None,
+    cmap: str = 'gray'
+) -> tuple[plt.Figure, np.ndarray]:
+    '''Plot sample images from a dataset.
+
+    Args:
+        dataset: PyTorch dataset containing (image, label) tuples.
+        class_names: List of class names for labeling.
+        nrows: Number of rows in the grid.
+        ncols: Number of columns in the grid.
+        figsize: Figure size (width, height). Defaults to (ncols*1.5, nrows*1.5).
+        cmap: Colormap for displaying images.
+
+    Returns:
+        Tuple of (figure, axes array).
+    '''
+    if figsize is None:
+        figsize = (ncols * 1.5, nrows * 1.5)
+
+    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
+    axes = axes.flatten()
+
+    for i, ax in enumerate(axes):
+        # Get image and label from dataset
+        img, label = dataset[i]
+
+        # Unnormalize and squeeze for plotting
+        img = img * 0.5 + 0.5
+        img = img.numpy().squeeze()
+        ax.set_title(class_names[label])
+        ax.imshow(img, cmap=cmap)
+        ax.axis('off')
+
+    plt.tight_layout()
+
+    return fig, axes
+
+
+def plot_learning_curves(
+    history: dict[str, list[float]],
+    figsize: tuple[float, float] = (10, 4)
+) -> tuple[plt.Figure, np.ndarray]:
+    '''Plot training and validation loss and accuracy curves.
+
+    Args:
+        history: Dictionary containing 'train_loss', 'val_loss',
+            'train_accuracy', and 'val_accuracy' lists.
+        figsize: Figure size (width, height).
+
+    Returns:
+        Tuple of (figure, axes array).
+    '''
+    fig, axes = plt.subplots(1, 2, figsize=figsize)
+
+    axes[0].set_title('Loss')
+    axes[0].plot(history['train_loss'], label='Train')
+    axes[0].plot(history['val_loss'], label='Validation')
+    axes[0].set_xlabel('Epoch')
+    axes[0].set_ylabel('Loss (cross-entropy)')
+    axes[0].legend(loc='best')
+
+    axes[1].set_title('Accuracy')
+    axes[1].plot(history['train_accuracy'], label='Train')
+    axes[1].plot(history['val_accuracy'], label='Validation')
+    axes[1].set_xlabel('Epoch')
+    axes[1].set_ylabel('Accuracy (%)')
+    axes[1].legend(loc='best')
+
+    plt.tight_layout()
+
+    return fig, axes
+
+
+def plot_confusion_matrix(
+    true_labels: np.ndarray,
+    predictions: np.ndarray,
+    class_names: list[str],
+    figsize: tuple[float, float] = (8, 8),
+    cmap: str = 'Blues'
+) -> tuple[plt.Figure, plt.Axes]:
+    '''Plot a confusion matrix heatmap.
+
+    Args:
+        true_labels: Array of true class labels.
+        predictions: Array of predicted class labels.
+        class_names: List of class names for labeling.
+        figsize: Figure size (width, height).
+        cmap: Colormap for the heatmap.
+
+    Returns:
+        Tuple of (figure, axes).
+    '''
+    from sklearn.metrics import confusion_matrix
+
+    cm = confusion_matrix(true_labels, predictions)
+
+    fig, ax = plt.subplots(figsize=figsize)
+
+    ax.set_title('Confusion matrix')
+    im = ax.imshow(cm, cmap=cmap)
+
+    # Add labels
+    ax.set_xticks(range(len(class_names)))
+    ax.set_yticks(range(len(class_names)))
+    ax.set_xticklabels(class_names, rotation=45, ha='right')
+    ax.set_yticklabels(class_names)
+    ax.set_xlabel('Predicted label')
+    ax.set_ylabel('True label')
+
+    # Add text annotations
+    for i in range(len(class_names)):
+        for j in range(len(class_names)):
+            color = 'white' if cm[i, j] > cm.max() / 2 else 'black'
+            ax.text(j, i, str(cm[i, j]), ha='center', va='center', color=color)
+
+    plt.tight_layout()
+
+    return fig, ax
+
+
+def plot_class_probability_distributions(
+    all_probs: np.ndarray,
+    class_names: list[str],
+    nrows: int = 2,
+    ncols: int = 5,
+    figsize: tuple[float, float] = (12, 4),
+    bins: int = 50,
+    color: str = 'black'
+) -> tuple[plt.Figure, np.ndarray]:
+    '''Plot predicted probability distributions for each class.
+
+    Args:
+        all_probs: Array of shape (n_samples, n_classes) with predicted probabilities.
+        class_names: List of class names for labeling.
+        nrows: Number of rows in the subplot grid.
+        ncols: Number of columns in the subplot grid.
+        figsize: Figure size (width, height).
+        bins: Number of histogram bins.
+        color: Histogram bar color.
+
+    Returns:
+        Tuple of (figure, axes array).
+    '''
+    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
+
+    fig.suptitle('Predicted probability distributions by class', fontsize=14, y=1.02)
+    fig.supxlabel('Predicted probability', fontsize=12)
+    fig.supylabel('Count', fontsize=12)
+
+    axes = axes.flatten()
+
+    for i, (ax, class_name) in enumerate(zip(axes, class_names)):
+        # Get probabilities for this class across all samples
+        class_probs = all_probs[:, i]
+
+        # Plot histogram
+        ax.hist(class_probs, bins=bins, color=color)
+        ax.set_title(class_name)
+        ax.set_xlim(0, 1)
+
+    plt.tight_layout()
+
+    return fig, axes
+
+
+def plot_evaluation_curves(
+    true_labels: np.ndarray,
+    all_probs: np.ndarray,
+    class_names: list[str],
+    figsize: tuple[float, float] = (12, 5)
+) -> tuple[plt.Figure, tuple[plt.Axes, plt.Axes]]:
+    '''Plot ROC and Precision-Recall curves for multi-class classification.
+
+    Args:
+        true_labels: Array of true class labels.
+        all_probs: Array of shape (n_samples, n_classes) with predicted probabilities.
+        class_names: List of class names for labeling.
+        figsize: Figure size (width, height).
+
+    Returns:
+        Tuple of (figure, (ax1, ax2)).
+    '''
+    from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
+    from sklearn.preprocessing import label_binarize
+
+    # Binarize true labels for one-vs-rest evaluation
+    y_test_bin = label_binarize(true_labels, classes=range(len(class_names)))
+
+    # Create figure with ROC and PR curves side by side
+    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
+
+    # Plot ROC curves for each class
+    ax1.set_title('ROC curves (one-vs-rest)')
+
+    for i, class_name in enumerate(class_names):
+        fpr, tpr, _ = roc_curve(y_test_bin[:, i], all_probs[:, i])
+        roc_auc = auc(fpr, tpr)
+        ax1.plot(fpr, tpr, label=class_name)
+
+    ax1.plot([0, 1], [0, 1], 'k--', label='Random classifier')
+    ax1.set_xlabel('False positive rate')
+    ax1.set_ylabel('True positive rate')
+    ax1.legend(loc='lower right', fontsize=12)
+    ax1.set_xlim([0, 1])
+    ax1.set_ylim([0, 1.05])
+
+    # Plot Precision-Recall curves for each class
+    ax2.set_title('Precision-recall curves (one-vs-rest)')
+
+    for i, class_name in enumerate(class_names):
+        precision, recall, _ = precision_recall_curve(y_test_bin[:, i], all_probs[:, i])
+        ap = average_precision_score(y_test_bin[:, i], all_probs[:, i])
+        ax2.plot(recall, precision)
+
+    # Random classifier baseline (horizontal line at class prevalence = 1/num_classes)
+    baseline = 1 / len(class_names)
+    ax2.axhline(y=baseline, color='k', linestyle='--')
+
+    ax2.set_xlabel('Recall')
+    ax2.set_ylabel('Precision')
+    ax2.set_xlim([0, 1])
+    ax2.set_ylim([0, 1.05])
+
+    plt.tight_layout()
+
+    return fig, (ax1, ax2)
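To show how the new module's API fits together, here is a hedged, self-contained sketch. Only the function names and signatures come from the diff above; the CIFAR-10 class names, the synthetic history/label/probability arrays, and the random-number setup are stand-ins invented for illustration:

    # Sketch only: exercises the new plotting helpers with stand-in data.
    import numpy as np
    from cifar10_tools.pytorch.plotting import (
        plot_learning_curves,
        plot_confusion_matrix,
        plot_class_probability_distributions,
        plot_evaluation_curves,
    )

    # Assumed CIFAR-10 class names (standard torchvision ordering).
    class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']

    # Synthetic training history and predictions; a real run would supply these.
    history = {'train_loss': [2.1, 1.6, 1.3], 'val_loss': [2.0, 1.7, 1.5],
               'train_accuracy': [25.0, 42.0, 55.0], 'val_accuracy': [24.0, 40.0, 50.0]}
    rng = np.random.default_rng(0)
    true_labels = rng.integers(0, 10, size=200)
    all_probs = rng.dirichlet(np.ones(10), size=200)   # each row sums to 1
    predictions = all_probs.argmax(axis=1)

    plot_learning_curves(history)
    plot_confusion_matrix(true_labels, predictions, class_names)
    plot_class_probability_distributions(all_probs, class_names)
    plot_evaluation_curves(true_labels, all_probs, class_names)

    # plot_sample_images is omitted here: it expects a dataset of (image tensor, label)
    # pairs normalized with mean 0.5 / std 0.5, matching the img * 0.5 + 0.5
    # unnormalization in the module above.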
{cifar10_tools-0.1.0.dist-info → cifar10_tools-0.2.0.dist-info}/RECORD
CHANGED
@@ -1,10 +1,11 @@
 cifar10_tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cifar10_tools/pytorch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cifar10_tools/pytorch/data.py,sha256=
+cifar10_tools/pytorch/data.py,sha256=09zodpjto0xLq95tDAyq57CFh6MSYRuUBPcMmQcyKZM,626
 cifar10_tools/pytorch/evaluation.py,sha256=i4tRYOqWATVqQVkWT_fATWRbzo9ziX2DDkXKPaiQlFE,923
+cifar10_tools/pytorch/plotting.py,sha256=B1ifJxbSEDpInnVk9c3o1fjVx534TPPKTWM5iusyzrE,7494
 cifar10_tools/pytorch/training.py,sha256=Sg6NlBT_DTyLzf-Ls3bYI8-8AwGFJblRj0MDnUmGP3Q,2642
 cifar10_tools/tensorflow/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cifar10_tools-0.
-cifar10_tools-0.
-cifar10_tools-0.
-cifar10_tools-0.
+cifar10_tools-0.2.0.dist-info/METADATA,sha256=3s6_5lP8rAnEu5F9r5YKU-EqUi9UO3mNUFK1ikVgUfc,1580
+cifar10_tools-0.2.0.dist-info/WHEEL,sha256=kJCRJT_g0adfAJzTx2GUMmS80rTJIVHRCfG0DQgLq3o,88
+cifar10_tools-0.2.0.dist-info/licenses/LICENSE,sha256=wtHfRwmCF5-_XUmYwrBKwJkGipvHVmh7GXJOKKeOe2U,1073
+cifar10_tools-0.2.0.dist-info/RECORD,,
{cifar10_tools-0.1.0.dist-info → cifar10_tools-0.2.0.dist-info}/WHEEL
File without changes

{cifar10_tools-0.1.0.dist-info → cifar10_tools-0.2.0.dist-info}/licenses/LICENSE
File without changes