megadetector 10.0.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- megadetector/__init__.py +0 -0
- megadetector/api/__init__.py +0 -0
- megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
- megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +125 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
- megadetector/classification/__init__.py +0 -0
- megadetector/classification/aggregate_classifier_probs.py +108 -0
- megadetector/classification/analyze_failed_images.py +227 -0
- megadetector/classification/cache_batchapi_outputs.py +198 -0
- megadetector/classification/create_classification_dataset.py +626 -0
- megadetector/classification/crop_detections.py +516 -0
- megadetector/classification/csv_to_json.py +226 -0
- megadetector/classification/detect_and_crop.py +853 -0
- megadetector/classification/efficientnet/__init__.py +9 -0
- megadetector/classification/efficientnet/model.py +415 -0
- megadetector/classification/efficientnet/utils.py +608 -0
- megadetector/classification/evaluate_model.py +520 -0
- megadetector/classification/identify_mislabeled_candidates.py +152 -0
- megadetector/classification/json_to_azcopy_list.py +63 -0
- megadetector/classification/json_validator.py +696 -0
- megadetector/classification/map_classification_categories.py +276 -0
- megadetector/classification/merge_classification_detection_output.py +509 -0
- megadetector/classification/prepare_classification_script.py +194 -0
- megadetector/classification/prepare_classification_script_mc.py +228 -0
- megadetector/classification/run_classifier.py +287 -0
- megadetector/classification/save_mislabeled.py +110 -0
- megadetector/classification/train_classifier.py +827 -0
- megadetector/classification/train_classifier_tf.py +725 -0
- megadetector/classification/train_utils.py +323 -0
- megadetector/data_management/__init__.py +0 -0
- megadetector/data_management/animl_to_md.py +161 -0
- megadetector/data_management/annotations/__init__.py +0 -0
- megadetector/data_management/annotations/annotation_constants.py +33 -0
- megadetector/data_management/camtrap_dp_to_coco.py +270 -0
- megadetector/data_management/cct_json_utils.py +566 -0
- megadetector/data_management/cct_to_md.py +184 -0
- megadetector/data_management/cct_to_wi.py +293 -0
- megadetector/data_management/coco_to_labelme.py +284 -0
- megadetector/data_management/coco_to_yolo.py +701 -0
- megadetector/data_management/databases/__init__.py +0 -0
- megadetector/data_management/databases/add_width_and_height_to_db.py +107 -0
- megadetector/data_management/databases/combine_coco_camera_traps_files.py +210 -0
- megadetector/data_management/databases/integrity_check_json_db.py +563 -0
- megadetector/data_management/databases/subset_json_db.py +195 -0
- megadetector/data_management/generate_crops_from_cct.py +200 -0
- megadetector/data_management/get_image_sizes.py +164 -0
- megadetector/data_management/labelme_to_coco.py +559 -0
- megadetector/data_management/labelme_to_yolo.py +349 -0
- megadetector/data_management/lila/__init__.py +0 -0
- megadetector/data_management/lila/create_lila_blank_set.py +556 -0
- megadetector/data_management/lila/create_lila_test_set.py +192 -0
- megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
- megadetector/data_management/lila/download_lila_subset.py +182 -0
- megadetector/data_management/lila/generate_lila_per_image_labels.py +777 -0
- megadetector/data_management/lila/get_lila_annotation_counts.py +174 -0
- megadetector/data_management/lila/get_lila_image_counts.py +112 -0
- megadetector/data_management/lila/lila_common.py +319 -0
- megadetector/data_management/lila/test_lila_metadata_urls.py +164 -0
- megadetector/data_management/mewc_to_md.py +344 -0
- megadetector/data_management/ocr_tools.py +873 -0
- megadetector/data_management/read_exif.py +964 -0
- megadetector/data_management/remap_coco_categories.py +195 -0
- megadetector/data_management/remove_exif.py +156 -0
- megadetector/data_management/rename_images.py +194 -0
- megadetector/data_management/resize_coco_dataset.py +665 -0
- megadetector/data_management/speciesnet_to_md.py +41 -0
- megadetector/data_management/wi_download_csv_to_coco.py +247 -0
- megadetector/data_management/yolo_output_to_md_output.py +594 -0
- megadetector/data_management/yolo_to_coco.py +984 -0
- megadetector/data_management/zamba_to_md.py +188 -0
- megadetector/detection/__init__.py +0 -0
- megadetector/detection/change_detection.py +840 -0
- megadetector/detection/process_video.py +479 -0
- megadetector/detection/pytorch_detector.py +1451 -0
- megadetector/detection/run_detector.py +1267 -0
- megadetector/detection/run_detector_batch.py +2172 -0
- megadetector/detection/run_inference_with_yolov5_val.py +1314 -0
- megadetector/detection/run_md_and_speciesnet.py +1604 -0
- megadetector/detection/run_tiled_inference.py +1044 -0
- megadetector/detection/tf_detector.py +209 -0
- megadetector/detection/video_utils.py +1379 -0
- megadetector/postprocessing/__init__.py +0 -0
- megadetector/postprocessing/add_max_conf.py +72 -0
- megadetector/postprocessing/categorize_detections_by_size.py +166 -0
- megadetector/postprocessing/classification_postprocessing.py +1943 -0
- megadetector/postprocessing/combine_batch_outputs.py +249 -0
- megadetector/postprocessing/compare_batch_results.py +2110 -0
- megadetector/postprocessing/convert_output_format.py +403 -0
- megadetector/postprocessing/create_crop_folder.py +629 -0
- megadetector/postprocessing/detector_calibration.py +570 -0
- megadetector/postprocessing/generate_csv_report.py +522 -0
- megadetector/postprocessing/load_api_results.py +223 -0
- megadetector/postprocessing/md_to_coco.py +428 -0
- megadetector/postprocessing/md_to_labelme.py +351 -0
- megadetector/postprocessing/md_to_wi.py +41 -0
- megadetector/postprocessing/merge_detections.py +392 -0
- megadetector/postprocessing/postprocess_batch_results.py +2140 -0
- megadetector/postprocessing/remap_detection_categories.py +226 -0
- megadetector/postprocessing/render_detection_confusion_matrix.py +677 -0
- megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +206 -0
- megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +82 -0
- megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1665 -0
- megadetector/postprocessing/separate_detections_into_folders.py +795 -0
- megadetector/postprocessing/subset_json_detector_output.py +964 -0
- megadetector/postprocessing/top_folders_to_bottom.py +238 -0
- megadetector/postprocessing/validate_batch_results.py +332 -0
- megadetector/taxonomy_mapping/__init__.py +0 -0
- megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
- megadetector/taxonomy_mapping/map_new_lila_datasets.py +211 -0
- megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +165 -0
- megadetector/taxonomy_mapping/preview_lila_taxonomy.py +543 -0
- megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
- megadetector/taxonomy_mapping/simple_image_download.py +231 -0
- megadetector/taxonomy_mapping/species_lookup.py +1008 -0
- megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
- megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
- megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
- megadetector/tests/__init__.py +0 -0
- megadetector/tests/test_nms_synthetic.py +335 -0
- megadetector/utils/__init__.py +0 -0
- megadetector/utils/ct_utils.py +1857 -0
- megadetector/utils/directory_listing.py +199 -0
- megadetector/utils/extract_frames_from_video.py +307 -0
- megadetector/utils/gpu_test.py +125 -0
- megadetector/utils/md_tests.py +2072 -0
- megadetector/utils/path_utils.py +2872 -0
- megadetector/utils/process_utils.py +172 -0
- megadetector/utils/split_locations_into_train_val.py +237 -0
- megadetector/utils/string_utils.py +234 -0
- megadetector/utils/url_utils.py +825 -0
- megadetector/utils/wi_platform_utils.py +968 -0
- megadetector/utils/wi_taxonomy_utils.py +1766 -0
- megadetector/utils/write_html_image_list.py +239 -0
- megadetector/visualization/__init__.py +0 -0
- megadetector/visualization/plot_utils.py +309 -0
- megadetector/visualization/render_images_with_thumbnails.py +243 -0
- megadetector/visualization/visualization_utils.py +1973 -0
- megadetector/visualization/visualize_db.py +630 -0
- megadetector/visualization/visualize_detector_output.py +498 -0
- megadetector/visualization/visualize_video_output.py +705 -0
- megadetector-10.0.15.dist-info/METADATA +115 -0
- megadetector-10.0.15.dist-info/RECORD +147 -0
- megadetector-10.0.15.dist-info/WHEEL +5 -0
- megadetector-10.0.15.dist-info/licenses/LICENSE +19 -0
- megadetector-10.0.15.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,827 @@
|
|
|
1
|
+
"""
|
|
2
|
+
|
|
3
|
+
train_classifier.py
|
|
4
|
+
|
|
5
|
+
Train an EfficientNet or ResNet classifier.
|
|
6
|
+
|
|
7
|
+
Currently the implementation of multi-label multi-class classification is
|
|
8
|
+
non-functional.
|
|
9
|
+
|
|
10
|
+
During training, start tensorboard from within the classification/ directory:
|
|
11
|
+
tensorboard --logdir run --bind_all --samples_per_plugin scalars=0,images=0
|
|
12
|
+
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
#%% Imports and constants
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import json
|
|
20
|
+
import os
|
|
21
|
+
import argparse
|
|
22
|
+
|
|
23
|
+
from collections import defaultdict
|
|
24
|
+
from collections.abc import Callable, Mapping, MutableMapping, Sequence
|
|
25
|
+
from datetime import datetime
|
|
26
|
+
from typing import Any
|
|
27
|
+
|
|
28
|
+
import numpy as np
|
|
29
|
+
import PIL.Image
|
|
30
|
+
import sklearn.metrics
|
|
31
|
+
import tqdm
|
|
32
|
+
|
|
33
|
+
import torch
|
|
34
|
+
from torch.utils import tensorboard
|
|
35
|
+
import torchvision as tv
|
|
36
|
+
from torchvision.datasets.folder import default_loader
|
|
37
|
+
|
|
38
|
+
from megadetector.classification import efficientnet, evaluate_model
|
|
39
|
+
from megadetector.classification.train_utils import (
|
|
40
|
+
HeapItem, recall_from_confusion_matrix, add_to_heap, fig_to_img,
|
|
41
|
+
imgs_with_confidences, load_dataset_csv, prefix_all_keys)
|
|
42
|
+
from megadetector.visualization import plot_utils
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
#%% Example usage
|
|
46
|
+
|
|
47
|
+
"""
|
|
48
|
+
python train_classifier.py run_idfg /ssd/crops_sq \
|
|
49
|
+
-m "efficientnet-b0" --pretrained --finetune --label-weighted \
|
|
50
|
+
--epochs 50 --batch-size 512 --lr 1e-4 \
|
|
51
|
+
--num-workers 12 --seed 123 \
|
|
52
|
+
--logdir run_idfg
|
|
53
|
+
"""
|
|
54
|
+
|
|
55
|
+
# Per-channel normalization statistics used by the torchvision ImageNet models,
# see https://pytorch.org/docs/stable/torchvision/models.html
MEANS = np.asarray([0.485, 0.456, 0.406])
STDS = np.asarray([0.229, 0.224, 0.225])

# Every architecture name accepted by build_model(): the EfficientNet variants
# known to the efficientnet package plus the torchvision ResNets, sorted.
VALID_MODELS = sorted(
    {'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'}.union(
        efficientnet.VALID_MODELS))
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class AverageMeter:
    """Keeps the most recently seen value together with a count-weighted
    running mean of everything seen since the last reset."""

    def __init__(self):
        self.reset()

    def reset(self) -> None:
        """Zero the latest value, running total, running mean, and count."""
        self.val = self.avg = self.sum = 0.0
        self.count = 0

    def update(self, val: float, n: int = 1) -> None:
        """Record `val` as if it had been observed `n` times.

        Args:
            val: float, the newly observed value
            n: int, how many observations `val` stands for (e.g. batch size)
        """
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class SimpleDataset(torch.utils.data.Dataset):
    """A simple dataset that simply returns images and labels.

    Images are read lazily with torchvision's `default_loader` from
    `os.path.join(img_base_dir, img_file)`; the optional `transform` is applied
    to each loaded image and the optional `target_transform` to each label.
    """

    def __init__(self,
                 img_files: Sequence[str],
                 labels: Sequence[Any],
                 sample_weights: Sequence[float] | None = None,
                 img_base_dir: str = '',
                 transform: Callable[[PIL.Image.Image], Any] | None = None,
                 target_transform: Callable[[Any], Any] | None = None):
        """Creates a SimpleDataset.

        Args:
            img_files: sequence of str, image paths relative to img_base_dir
            labels: sequence, one label per entry of img_files (same order)
            sample_weights: optional sequence of float, one weight per image;
                when given, every item returned by __getitem__ also carries
                its weight
            img_base_dir: str, directory prepended to every entry of img_files
            transform: optional callable applied to each loaded image
            target_transform: optional callable applied to each label
        """
        self.img_files = img_files
        self.labels = labels
        self.sample_weights = sample_weights
        self.img_base_dir = img_base_dir
        self.transform = transform
        self.target_transform = target_transform

        self.len = len(img_files)
        # every per-image sequence must line up with img_files
        assert len(labels) == self.len
        if sample_weights is not None:
            assert len(sample_weights) == self.len

    def __getitem__(self, index: int) -> tuple[Any, ...]:
        """
        Args:
            index: int

        Returns: tuple, (sample, target, img_file) or, when sample_weights was
            given, (sample, target, img_file, sample_weight). img_file is the
            path as stored in img_files, i.e. relative to img_base_dir.
        """
        img_file = self.img_files[index]
        img = default_loader(os.path.join(self.img_base_dir, img_file))
        if self.transform is not None:
            img = self.transform(img)
        target = self.labels[index]
        if self.target_transform is not None:
            target = self.target_transform(target)
        if self.sample_weights is not None:
            return img, target, img_file, self.sample_weights[index]
        return img, target, img_file

    def __len__(self) -> int:
        return self.len
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def create_dataloaders(
        dataset_csv_path: str,
        label_index_json_path: str,
        splits_json_path: str,
        cropped_images_dir: str,
        img_size: int,
        multilabel: bool,
        label_weighted: bool,
        weight_by_detection_conf: bool | str,
        batch_size: int,
        num_workers: int,
        augment_train: bool
        ) -> tuple[dict[str, torch.utils.data.DataLoader], list[str]]:
    """Builds one DataLoader per split defined in the splits JSON file.

    Args:
        dataset_csv_path: str, path to CSV file with columns
            ['dataset', 'location', 'label'], where label is a comma-delimited
            list of labels
        label_index_json_path: str, path to the label-index JSON file, passed
            through to load_dataset_csv()
        splits_json_path: str, path to JSON file
        cropped_images_dir: str, base directory that image paths are
            relative to
        img_size: int, side length of the square images fed to the model
        multilabel: bool, passed through to load_dataset_csv()
        label_weighted: bool, whether per-image weights are used
        weight_by_detection_conf: bool or str, passed through to
            load_dataset_csv(); when truthy, per-image weights are used
        batch_size: int, DataLoader batch size
        num_workers: int, DataLoader worker count
        augment_train: bool, whether to shuffle/augment the training set

    Returns:
        datasets: dict, maps split to DataLoader
        label_names: list of str, label names in order of label id
    """
    df, label_names, split_to_locs = load_dataset_csv(
        dataset_csv_path, label_index_json_path, splits_json_path,
        multilabel=multilabel, label_weighted=label_weighted,
        weight_by_detection_conf=weight_by_detection_conf)

    # shared final step: normalizes the ToTensor() output in-place
    normalize = tv.transforms.Normalize(mean=MEANS, std=STDS, inplace=True)
    augmenting_steps = [
        tv.transforms.RandomResizedCrop(img_size),
        tv.transforms.RandomRotation(degrees=(-90, 90)),
        tv.transforms.RandomHorizontalFlip(p=0.5),
        tv.transforms.RandomVerticalFlip(p=0.1),
        tv.transforms.RandomGrayscale(p=0.1),
        tv.transforms.ColorJitter(brightness=.25, contrast=.25, saturation=.25),
        tv.transforms.ToTensor(),
        normalize,
    ]
    deterministic_steps = [
        # Resize with an int resizes the smaller edge to img_size
        tv.transforms.Resize(img_size, interpolation=PIL.Image.BICUBIC),
        tv.transforms.CenterCrop(img_size),
        tv.transforms.ToTensor(),
        normalize,
    ]
    train_transform = tv.transforms.Compose(augmenting_steps)
    test_transform = tv.transforms.Compose(deterministic_steps)

    use_weights = label_weighted or weight_by_detection_conf

    dataloaders = {}
    for split, locs in split_to_locs.items():
        augment = augment_train and split == 'train'
        split_df = df[df['dataset_location'].isin(locs)]
        num_rows = len(split_df)

        weights = None
        sampler = None
        if use_weights:
            weights = split_df['weights'].to_numpy()
            if not weight_by_detection_conf:
                # pure label weighting: weights are expected to sum to the
                # number of images in the split
                assert np.isclose(weights.sum(), num_rows)
            if augment:
                sampler = torch.utils.data.WeightedRandomSampler(
                    weights, num_samples=num_rows, replacement=True)
        elif augment:
            # plain (unweighted) shuffling
            sampler = torch.utils.data.SubsetRandomSampler(range(num_rows))

        if augment:
            transform = train_transform
        else:
            transform = test_transform
        dataset = SimpleDataset(
            img_files=split_df['path'].tolist(),
            labels=split_df['label_index'].tolist(),
            sample_weights=weights,
            img_base_dir=cropped_images_dir,
            transform=transform)
        assert len(dataset) > 0
        dataloaders[split] = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, sampler=sampler,
            num_workers=num_workers, pin_memory=True)

    return dataloaders, label_names
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def set_finetune(model: torch.nn.Module, model_name: str, finetune: bool
                 ) -> None:
    """Sets `requires_grad` on every parameter of `model` for fine-tuning.

    Args:
        model: torch.nn.Module, an EfficientNet (head stored as `_fc`) or a
            torchvision ResNet (head stored as `fc`)
        model_name: str, used only to decide which head attribute to read
        finetune: bool, if truthy only the final FC layer stays trainable;
            otherwise every parameter is made trainable
    """
    if not finetune:
        model.requires_grad_(True)
        return

    if 'efficientnet' in model_name:
        head = model._fc  # pylint: disable=protected-access
    else:  # torchvision resnet
        head = model.fc
    assert isinstance(head, torch.nn.Module)

    # freeze everything, then unfreeze only the classifier head
    model.requires_grad_(False)
    head.requires_grad_(True)
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def build_model(model_name: str, num_classes: int, pretrained: bool | str,
                finetune: bool) -> torch.nn.Module:
    """Builds an EfficientNet- or ResNet-based classifier that outputs
    unnormalized logits.

    Args:
        model_name: str, one of VALID_MODELS
        num_classes: int, size of the final output layer
        pretrained: bool or str; True loads ImageNet weights, a str is the
            path of a checkpoint whose 'model' entry is a state dict
        finetune: bool, whether to freeze all layers except the final FC layer

    Returns: torch.nn.Module, model loaded on CPU
    """
    assert model_name in VALID_MODELS

    use_imagenet_weights = pretrained is True
    if 'efficientnet' not in model_name:
        resnet_ctor = getattr(tv.models, model_name)
        model = resnet_ctor(pretrained=use_imagenet_weights)
        # swap the 1000-way ImageNet head for one sized to our classes
        model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    elif use_imagenet_weights:
        model = efficientnet.EfficientNet.from_pretrained(
            model_name, num_classes=num_classes)
    else:
        model = efficientnet.EfficientNet.from_name(
            model_name, num_classes=num_classes)

    if isinstance(pretrained, str):
        print(f'Loading saved weights from {pretrained}')
        checkpoint = torch.load(pretrained, map_location='cpu')
        model.load_state_dict(checkpoint['model'])

    assert all(p.requires_grad for p in model.parameters())
    set_finetune(model=model, model_name=model_name, finetune=finetune)
    return model
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def prep_device(model: torch.nn.Module, device_id: int | None = None
                ) -> tuple[torch.nn.Module, torch.device]:
    """Moves `model` onto the best available device.

    Args:
        model: torch.nn.Module, not already wrapped with DataParallel
        device_id: optional int, GPU to use; when None and several GPUs are
            visible, the model is wrapped in DataParallel over all of them

    Returns:
        model: torch.nn.Module, the model on `device` (possibly wrapped in
            DataParallel)
        device: torch.device, 'cuda:0' / 'cuda:{device_id}' if CUDA is
            available, otherwise 'cpu'
    """
    if not torch.cuda.is_available():
        print('CUDA not available, running on the CPU')
        device = torch.device('cpu')
    else:
        print('CUDA available')
        torch.backends.cudnn.benchmark = True
        if device_id is None:
            device = torch.device('cuda:0')
            gpu_ids = list(range(torch.cuda.device_count()))
            if len(gpu_ids) > 1:
                print(f'Found multiple devices, enabling data parallelism ({gpu_ids})')
                model = torch.nn.DataParallel(model, device_ids=gpu_ids)
        else:
            print(f'Starting CUDA device {device_id}')
            device = torch.device(f'cuda:{device_id}')
    model.to(device)  # nn.Module.to() moves parameters in-place
    return model, device
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
def main(dataset_dir: str,
         cropped_images_dir: str,
         multilabel: bool,
         model_name: str,
         pretrained: bool | str,
         finetune: int,
         label_weighted: bool,
         weight_by_detection_conf: bool | str,
         epochs: int,
         batch_size: int,
         lr: float,
         weight_decay: float,
         num_workers: int,
         logdir: str,
         log_extreme_examples: int,
         seed: int | None = None) -> None:
    """Main function: trains a classifier, then evaluates the best checkpoint.

    Writes params.json, TensorBoard logs, and one checkpoint per new-best
    epoch (by val top-1 accuracy) into a timestamped subdirectory of `logdir`.
    The test split is evaluated whenever a new best epoch is found. Training
    stops early after 8 epochs without improvement.

    Args:
        dataset_dir: str, directory containing classification_ds.csv,
            label_index.json and splits.json
        cropped_images_dir: str, base directory of the cropped images
        multilabel: bool, use BCEWithLogitsLoss instead of CrossEntropyLoss
        model_name: str, one of VALID_MODELS
        pretrained: bool or str, see build_model()
        finetune: int, number of initial epochs during which only the final
            FC layer is trained (0 = train all layers from the start)
        label_weighted: bool, see create_dataloaders()
        weight_by_detection_conf: bool or str, see create_dataloaders()
        epochs: int, maximum number of epochs
        batch_size: int
        lr: float, initial learning rate
        weight_decay: float
        num_workers: int, DataLoader workers
        logdir: str, parent directory for this run's log directory
        log_extreme_examples: int, number of extreme examples to log
        seed: optional int, random seed; drawn at random if None
    """
    # input validation
    assert os.path.exists(dataset_dir)
    assert os.path.exists(cropped_images_dir)
    if isinstance(weight_by_detection_conf, str):
        assert os.path.exists(weight_by_detection_conf)
    if isinstance(pretrained, str):
        assert os.path.exists(pretrained)

    # set seed
    seed = np.random.randint(10_000) if seed is None else seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # create logdir and save params
    # - keep this right after seeding: at this point locals() holds exactly
    #   the function arguments (with the resolved seed)
    params = dict(locals())  # make a copy
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')  # '20200722_110816'
    logdir = os.path.join(logdir, timestamp)
    os.makedirs(logdir, exist_ok=True)
    print('Created logdir:', logdir)
    params_json_path = os.path.join(logdir, 'params.json')
    with open(params_json_path, 'w') as f:
        json.dump(params, f, indent=1)

    if 'efficientnet' in model_name:
        img_size = efficientnet.EfficientNet.get_image_size(model_name)
    else:
        img_size = 224

    # create dataloaders and log the index_to_label mapping
    print('Creating dataloaders')
    loaders, label_names = create_dataloaders(
        dataset_csv_path=os.path.join(dataset_dir, 'classification_ds.csv'),
        label_index_json_path=os.path.join(dataset_dir, 'label_index.json'),
        splits_json_path=os.path.join(dataset_dir, 'splits.json'),
        cropped_images_dir=cropped_images_dir,
        img_size=img_size,
        multilabel=multilabel,
        label_weighted=label_weighted,
        weight_by_detection_conf=weight_by_detection_conf,
        batch_size=batch_size,
        num_workers=num_workers,
        augment_train=True)

    writer = tensorboard.SummaryWriter(logdir)

    # create model
    model = build_model(model_name, num_classes=len(label_names),
                        pretrained=pretrained, finetune=finetune > 0)
    model, device = prep_device(model)

    # define loss function and optimizer
    loss_fn: torch.nn.Module
    if multilabel:
        loss_fn = torch.nn.BCEWithLogitsLoss(reduction='none').to(device)
    else:
        loss_fn = torch.nn.CrossEntropyLoss(reduction='none').to(device)

    # using EfficientNet training defaults
    # - batch norm momentum: 0.99
    # - optimizer: RMSProp, decay 0.9 and momentum 0.9
    # - epochs: 350
    # - learning rate: 0.256, decays by 0.97 every 2.4 epochs
    # - weight decay: 1e-5
    optimizer: torch.optim.Optimizer
    if 'efficientnet' in model_name:
        optimizer = torch.optim.RMSprop(model.parameters(), lr, alpha=0.9,
                                        momentum=0.9, weight_decay=weight_decay)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer, step_size=1, gamma=0.97 ** (1 / 2.4))
    else:  # resnet
        optimizer = torch.optim.SGD(model.parameters(), lr, momentum=0.9,
                                    weight_decay=weight_decay)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer, step_size=8, gamma=0.1)  # lower every 8 epochs

    best_epoch_metrics: dict[str, float] = {}
    for epoch in range(epochs):
        print(f'Epoch: {epoch}')
        writer.add_scalar('lr', lr_scheduler.get_last_lr()[0], epoch)

        if epoch > 0 and finetune == epoch:
            print('Turning off fine-tune!')
            set_finetune(model, model_name, finetune=False)

        print('- train:')
        train_metrics, train_heaps, train_cm = run_epoch(
            model, loader=loaders['train'], weighted=False, device=device,
            loss_fn=loss_fn, finetune=finetune > epoch, optimizer=optimizer,
            k_extreme=log_extreme_examples)
        train_metrics = prefix_all_keys(train_metrics, prefix='train/')
        log_run('train', epoch, writer, label_names,
                metrics=train_metrics, heaps=train_heaps, cm=train_cm)
        del train_heaps

        print('- val:')
        val_metrics, val_heaps, val_cm = run_epoch(
            model, loader=loaders['val'], weighted=label_weighted,
            device=device, loss_fn=loss_fn, k_extreme=log_extreme_examples)
        val_metrics = prefix_all_keys(val_metrics, prefix='val/')
        log_run('val', epoch, writer, label_names,
                metrics=val_metrics, heaps=val_heaps, cm=val_cm)
        del val_heaps

        lr_scheduler.step()  # decrease the learning rate

        # The first epoch always becomes the initial "best". Previously a
        # first-epoch val accuracy of 0 (or NaN) never beat the default of 0,
        # so best_epoch_metrics['epoch'] was missing and the early-stopping
        # check below raised KeyError.
        no_best_yet = 'epoch' not in best_epoch_metrics
        if no_best_yet or (val_metrics['val/acc_top1']
                           > best_epoch_metrics['val/acc_top1']):
            filename = os.path.join(logdir, f'ckpt_{epoch}.pt')
            print(f'New best model! Saving checkpoint to {filename}')
            state = {
                'epoch': epoch,
                # unwrap DataParallel (if any) so keys have no 'module.' prefix
                'model': getattr(model, 'module', model).state_dict(),
                'val/acc': val_metrics['val/acc_top1'],
                'optimizer': optimizer.state_dict()
            }
            torch.save(state, filename)
            best_epoch_metrics.update(train_metrics)
            best_epoch_metrics.update(val_metrics)
            best_epoch_metrics['epoch'] = epoch

            print('- test:')
            test_metrics, test_heaps, test_cm = run_epoch(
                model, loader=loaders['test'], weighted=label_weighted,
                device=device, loss_fn=loss_fn, k_extreme=log_extreme_examples)
            test_metrics = prefix_all_keys(test_metrics, prefix='test/')
            log_run('test', epoch, writer, label_names,
                    metrics=test_metrics, heaps=test_heaps, cm=test_cm)
            del test_heaps

        # stop training after 8 epochs without improvement
        if epoch >= best_epoch_metrics['epoch'] + 8:
            break

    hparams_dict = {
        'model_name': model_name,
        'multilabel': multilabel,
        'finetune': finetune,
        'batch_size': batch_size,
        'epochs': epochs
    }
    metric_dict = prefix_all_keys(best_epoch_metrics, prefix='hparam/')
    writer.add_hparams(hparam_dict=hparams_dict, metric_dict=metric_dict)
    writer.close()

    if 'epoch' not in best_epoch_metrics:
        # only reachable when epochs <= 0: no checkpoint was ever written
        print('No epochs were run, so no checkpoint exists; skipping evaluation')
        return

    # do a complete evaluation run
    best_epoch = best_epoch_metrics['epoch']
    evaluate_model.main(
        params_json_path=params_json_path,
        ckpt_path=os.path.join(logdir, f'ckpt_{best_epoch}.pt'),
        output_dir=logdir, splits=evaluate_model.SPLITS)
|
|
474
|
+
|
|
475
|
+
|
|
476
|
+
def log_run(split: str, epoch: int, writer: tensorboard.SummaryWriter,
            label_names: Sequence[str], metrics: MutableMapping[str, float],
            heaps: Mapping[str, Mapping[int, list[HeapItem]]] | None,
            cm: np.ndarray) -> None:
    """Writes one epoch's results for `split` to TensorBoard.

    Logged items:
    - scalar metrics, plus per-label recall derived from `cm`
    - the normalized confusion matrix, as an image
    - tp/fp/fn example images, when `heaps` is given

    Note: `metrics` is updated in place with the per-label recall entries.

    Args:
        split: str, split name used in tags (e.g. 'train', 'val', 'test')
        epoch: int, global step for all logged items
        writer: tensorboard.SummaryWriter
        label_names: list of str, label names in order of label id
        metrics: dict, keys already prefixed with {split}/
        heaps: optional dict, maps heap type to {label_id: list of HeapItem}
        cm: np.ndarray, confusion matrix
    """
    recall_by_label = recall_from_confusion_matrix(cm, label_names)
    metrics.update(prefix_all_keys(recall_by_label, f'{split}/label_recall/'))

    for tag_name, scalar in metrics.items():
        writer.add_scalar(tag_name, scalar, epoch)

    figure = plot_utils.plot_confusion_matrix(cm, classes=label_names,
                                              normalize=True)
    writer.add_image(tag=f'confusion_matrix/{split}',
                     img_tensor=fig_to_img(figure),
                     global_step=epoch, dataformats='HWC')

    if heaps is not None:
        for heap_name, per_label_heaps in heaps.items():
            log_images_with_confidence(writer, per_label_heaps, label_names,
                                       epoch=epoch, tag=f'{split}/{heap_name}')
    writer.flush()
|
|
506
|
+
|
|
507
|
+
|
|
508
|
+
def log_images_with_confidence(
        writer: tensorboard.SummaryWriter,
        heap_dict: Mapping[int, list[HeapItem]],
        label_names: Sequence[str],
        epoch: int,
        tag: str) -> None:
    """For each label, logs a figure of its heap's images (largest priority
    first) plus a text entry listing the corresponding image files.

    The tensors stored in the HeapItems are NOT modified: un-normalization is
    done out-of-place on a copy, because the same HeapItem may be referenced
    from several heaps.

    Args:
        writer: tensorboard.SummaryWriter
        heap_dict: dict, maps label_id to list of HeapItem, where each HeapItem
            data is a tuple (img, target, top3_conf, top3_preds, img_file)
        label_names: list of str, label names in order of label id
        epoch: int
        tag: str
    """
    # inverse of the training Normalize(MEANS, STDS); inplace=False (the
    # default, made explicit here) so heap tensors are never overwritten
    unnormalize = tv.transforms.Normalize(mean=-MEANS/STDS, std=1.0/STDS,
                                          inplace=False)
    for label_id, heap in heap_dict.items():
        label_name = label_names[label_id]

        imgs_list = []
        for item in sorted(heap, reverse=True):  # sort largest to smallest
            img = item.data[0].float()  # clamp() only supports fp32 on CPU
            # undo normalization, clamp to [0, 1], CHW -> HWC
            img = unnormalize(img).clamp_(0, 1).permute(1, 2, 0)
            imgs_list.append((img, *item.data[1:]))

        fig, img_files = imgs_with_confidences(imgs_list, label_names)

        # writer.add_figure() has issues => using add_image() instead
        writer.add_image(f'{label_name}/{tag}', fig_to_img(fig),
                         global_step=epoch, dataformats='HWC')
        writer.add_text(f'{label_name}/{tag}_files', '\n\n'.join(img_files),
                        global_step=epoch)
|
|
545
|
+
|
|
546
|
+
|
|
547
|
+
def track_extreme_examples(tp_heaps: dict[int, list[HeapItem]],
                           fp_heaps: dict[int, list[HeapItem]],
                           fn_heaps: dict[int, list[HeapItem]],
                           inputs: torch.Tensor,
                           labels: torch.Tensor,
                           img_files: Sequence[str],
                           logits: torch.Tensor,
                           k: int = 5) -> None:
    """Updates the k most extreme true-positive (tp), false-positive (fp), and
    false-negative (fn) examples with examples from this batch.

    Each HeapItem's data attribute is a list of:
    - img: torch.Tensor, shape [3, H, W], type float16, a CPU copy of the
        model input as given (values are not necessarily in [0, 1], e.g. if
        the input was normalized)
    - label: int
    - top3_conf: list of float, highest softmax confidences in descending
        order (only num_classes entries if there are fewer than 3 classes)
    - top3_preds: list of int, label ids corresponding to top3_conf
    - img_file: str

    Priorities:
    - tp: confidence of the true label minus the 2nd-highest confidence
        (0.0 is used as the 2nd-highest confidence if there is only 1 class)
    - fp/fn: highest confidence minus confidence of the true label; the same
        HeapItem is added to fp_heaps[predicted label] and fn_heaps[label]

    Args:
        *_heaps: dict, maps label_id (int) to heap of HeapItems
        inputs: torch.Tensor, shape [batch_size, 3, H, W]
        labels: torch.Tensor, shape [batch_size]
        img_files: list of str
        logits: torch.Tensor, shape [batch_size, num_classes]
        k: int, number of examples to track
    """
    with torch.no_grad():
        inputs = inputs.detach().to(device='cpu', dtype=torch.float16)
        labels_list = labels.tolist()
        batch_probs = torch.nn.functional.softmax(logits, dim=1).cpu()
        # topk() raises if k exceeds the number of classes (e.g. binary tasks)
        num_top = min(3, batch_probs.size(1))
        zipped = zip(inputs, labels_list, batch_probs, img_files)  # all on CPU
        for img, label, confs, img_file in zipped:
            label_conf = confs[label].item()

            top3_conf, top3_preds = confs.topk(num_top)
            top3_conf = top3_conf.tolist()
            top3_preds = top3_preds.tolist()

            data = [img, label, top3_conf, top3_preds, img_file]
            if top3_preds[0] == label:  # true positive
                # margin over the runner-up class (none exists with 1 class)
                runner_up_conf = top3_conf[1] if len(top3_conf) > 1 else 0.0
                item = HeapItem(priority=label_conf - runner_up_conf,
                                data=data)
                add_to_heap(tp_heaps[label], item, k=k)
            else:
                # false positive for top3_pred[0]
                # false negative for label
                item = HeapItem(priority=top3_conf[0] - label_conf, data=data)
                add_to_heap(fp_heaps[top3_preds[0]], item, k=k)
                add_to_heap(fn_heaps[label], item, k=k)
|
|
595
|
+
|
|
596
|
+
|
|
597
|
+
def correct(outputs: torch.Tensor, labels: torch.Tensor,
|
|
598
|
+
weights: torch.Tensor | None = None,
|
|
599
|
+
top: Sequence[int] = (1,)) -> dict[int, float]:
|
|
600
|
+
"""
|
|
601
|
+
Args:
|
|
602
|
+
outputs: torch.Tensor, shape [N, num_classes],
|
|
603
|
+
either logits (pre-softmax) or probabilities
|
|
604
|
+
labels: torch.Tensor, shape [N]
|
|
605
|
+
weights: optional torch.Tensor, shape [N]
|
|
606
|
+
top: tuple of int, list of values of k for calculating top-K accuracy
|
|
607
|
+
|
|
608
|
+
Returns: dict, maps k to (weighted) # of correct predictions @ each k
|
|
609
|
+
"""
|
|
610
|
+
with torch.no_grad():
|
|
611
|
+
# preds and labels both have shape [N, k]
|
|
612
|
+
_, preds = outputs.topk(k=max(top), dim=1, largest=True, sorted=True)
|
|
613
|
+
labels = labels.view(-1, 1).expand_as(preds)
|
|
614
|
+
|
|
615
|
+
corrects = preds.eq(labels).cumsum(dim=1) # shape [N, k]
|
|
616
|
+
if weights is None:
|
|
617
|
+
corrects = corrects.sum(dim=0) # shape [k]
|
|
618
|
+
else:
|
|
619
|
+
corrects = weights.matmul(corrects.to(weights.dtype)) # shape [k]
|
|
620
|
+
tops = {k: corrects[k - 1].item() for k in top}
|
|
621
|
+
return tops
|
|
622
|
+
|
|
623
|
+
|
|
624
|
+
def run_epoch(model: torch.nn.Module,
              loader: torch.utils.data.DataLoader,
              weighted: bool,
              device: torch.device,
              top: Sequence[int] = (1, 3),
              loss_fn: torch.nn.Module | None = None,
              finetune: bool = False,
              optimizer: torch.optim.Optimizer | None = None,
              k_extreme: int = 0
              ) -> tuple[
                  dict[str, float],
                  dict[str, dict[int, list[HeapItem]]] | None,
                  np.ndarray
              ]:
    """Runs for 1 epoch.

    Trains if an optimizer is given, otherwise only evaluates (gradients are
    disabled). Batches are (inputs, labels, img_files[, weights]).

    Args:
        model: torch.nn.Module
        loader: torch.utils.data.DataLoader
        weighted: bool, whether to use sample weights in calculating loss and
            accuracy
        device: torch.device
        top: tuple of int, list of values of k for calculating top-K accuracy
        loss_fn: optional loss function, calculates per-example loss
        finetune: bool, if true sets model's dropout and BN layers to eval mode
        optimizer: optional optimizer, requires loss_fn to be given
        k_extreme: int, # of tp/fp/fn examples to track for each label

    Returns:
        metrics: dict, metrics from epoch, contains keys:
            'loss': float, mean per-example loss over entire epoch,
                only included if loss_fn is not None
            'acc_top{k}': float, accuracy@k (in percent) over the entire epoch
        heaps: dict, keys are ['tp', 'fp', 'fn'], values are heap_dicts,
            each heap_dict maps label_id (int) to a heap of <= k_extreme
            HeapItems with data attribute
            (img, target, top3_conf, top3_preds, img_file)
            - 'tp': priority is the difference between target confidence and
                2nd highest confidence
            - 'fp': priority is the difference between highest confidence and
                target confidence
            - 'fn': same as 'fp'
            None if k_extreme == 0
        confusion_matrix: np.ndarray, shape [num_classes, num_classes],
            C[i, j] = # of samples with true label i, predicted as label j
    """
    # training needs a loss to backpropagate
    if optimizer is not None:
        assert loss_fn is not None

    # if evaluating or finetuning, set dropout and BN layers to eval mode
    model.train(optimizer is not None and not finetune)

    if loss_fn is not None:
        losses = AverageMeter()
    accuracies_topk = {k: AverageMeter() for k in top}  # acc@k

    # for each label, track k_extreme most-confident and least-confident images
    if k_extreme > 0:
        tp_heaps: dict[int, list[HeapItem]] = defaultdict(list)
        fp_heaps: dict[int, list[HeapItem]] = defaultdict(list)
        fn_heaps: dict[int, list[HeapItem]] = defaultdict(list)

    # per-example labels / argmax predictions for the confusion matrix,
    # filled batch by batch in loader order
    all_labels = np.zeros(len(loader.dataset), dtype=np.int32)
    all_preds = np.zeros_like(all_labels)
    end_i = 0

    tqdm_loader = tqdm.tqdm(loader)
    with torch.set_grad_enabled(optimizer is not None):
        for batch in tqdm_loader:
            if weighted:
                inputs, labels, img_files, weights = batch
                weights = weights.to(device, non_blocking=True)
            else:
                # even if batch contains sample weights, don't use them
                inputs, labels, img_files = batch[0:3]
                weights = None

            inputs = inputs.to(device, non_blocking=True)

            batch_size = labels.size(0)
            start_i = end_i
            end_i = start_i + batch_size
            # labels are still on the CPU here (moved to device just below)
            all_labels[start_i:end_i] = labels

            desc = []
            labels = labels.to(device, non_blocking=True)
            outputs = model(inputs)
            all_preds[start_i:end_i] = outputs.detach().argmax(dim=1).cpu()

            if loss_fn is not None:
                loss = loss_fn(outputs, labels)
                # scale each example's loss by its weight; assumes loss_fn
                # returns per-example losses (e.g. reduction='none')
                if weights is not None:
                    loss *= weights
                loss = loss.mean()
                losses.update(loss.item(), n=batch_size)
                desc.append(f'Loss {losses.val:.4f} ({losses.avg:.4f})')
                if optimizer is not None:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

            # NOTE(review): with weights, top_correct[k] is a weighted count,
            # so this is a true percentage only if weights average to 1
            top_correct = correct(outputs, labels, weights=weights, top=top)
            for k, acc in accuracies_topk.items():
                acc.update(top_correct[k] * (100. / batch_size), n=batch_size)
                desc.append(f'Acc@{k} {acc.val:.3f} ({acc.avg:.3f})')
            tqdm_loader.set_description(' '.join(desc))

            if k_extreme > 0:
                track_extreme_examples(tp_heaps, fp_heaps, fn_heaps, inputs,
                                       labels, img_files, outputs, k=k_extreme)

    # NOTE(review): `outputs` comes from the last batch; if the loader yields
    # no batches it is unbound and this line raises
    num_classes = outputs.size(1)
    confusion_matrix = sklearn.metrics.confusion_matrix(
        all_labels, all_preds, labels=np.arange(num_classes))

    metrics = {}
    if loss_fn is not None:
        metrics['loss'] = losses.avg
    for k, acc in accuracies_topk.items():
        metrics[f'acc_top{k}'] = acc.avg
    heaps = None
    if k_extreme > 0:
        heaps = {'tp': tp_heaps, 'fp': fp_heaps, 'fn': fn_heaps}
    return metrics, heaps, confusion_matrix
|
|
746
|
+
|
|
747
|
+
|
|
748
|
+
def _parse_args() -> argparse.Namespace:
    """Builds the classifier-training command-line interface and parses
    sys.argv with it.

    Returns: argparse.Namespace. Optional arguments that are not given keep
        their defaults; '--lr' and '--seed' have no default and are None.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Trains classifier.')
    add = parser.add_argument

    # positional inputs
    add('dataset_dir',
        help='path to directory containing: 1) classification dataset CSV, '
             '2) label index JSON, 3) splits JSON')
    add('cropped_images_dir',
        help='path to local directory where image crops are saved')

    # model and training mode
    add('--multilabel', action='store_true',
        help='for multi-label, multi-class classification')
    add('-m', '--model-name', default='efficientnet-b0', choices=VALID_MODELS,
        help='which EfficientNet or Resnet model')
    add('--pretrained', nargs='?', const=True, default=False,
        help='start with ImageNet pretrained model or a specific checkpoint')
    add('--finetune', type=int, default=0,
        help='only fine tune the final fully-connected layer for the first '
             '<finetune> epochs')

    # sample weighting
    add('--label-weighted', action='store_true',
        help='weight training samples to balance labels')
    add('--weight-by-detection-conf', nargs='?', const=True, default=False,
        help='weight training examples by detection confidence. '
             'Optionally takes a .npz file for isotonic calibration.')

    # optimization hyperparameters
    add('--epochs', type=int, default=0,
        help='number of epochs for training, 0 for eval-only')
    add('--batch-size', type=int, default=256,
        help='batch size for both training and eval')
    add('--lr', type=float,
        help='initial learning rate, defaults to (0.016 * batch_size / 256)')
    add('--weight-decay', type=float, default=1e-5,
        help='weight decay')

    # runtime and logging
    add('--num-workers', type=int, default=8,
        help='# of workers for data loading')
    add('--logdir', default='.',
        help='directory where TensorBoard logs and a params file are saved')
    add('--log-extreme-examples', type=int, default=0,
        help='# of tp/fp/fn examples to log for each label and split per epoch')
    add('--seed', type=int,
        help='random seed')

    return parser.parse_args()
|
|
806
|
+
|
|
807
|
+
|
|
808
|
+
if __name__ == '__main__':
    args = _parse_args()
    if args.lr is None:
        # scale the base learning rate linearly with batch size
        args.lr = 0.016 * args.batch_size / 256  # based on TF models repo
    # every argparse dest (dataset_dir, cropped_images_dir, multilabel,
    # model_name, pretrained, finetune, label_weighted,
    # weight_by_detection_conf, epochs, batch_size, lr, weight_decay,
    # num_workers, logdir, log_extreme_examples, seed) is a keyword
    # parameter of main() with the same name
    main(**vars(args))
|