megadetector 10.0.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- megadetector/__init__.py +0 -0
- megadetector/api/__init__.py +0 -0
- megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
- megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +125 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
- megadetector/classification/__init__.py +0 -0
- megadetector/classification/aggregate_classifier_probs.py +108 -0
- megadetector/classification/analyze_failed_images.py +227 -0
- megadetector/classification/cache_batchapi_outputs.py +198 -0
- megadetector/classification/create_classification_dataset.py +626 -0
- megadetector/classification/crop_detections.py +516 -0
- megadetector/classification/csv_to_json.py +226 -0
- megadetector/classification/detect_and_crop.py +853 -0
- megadetector/classification/efficientnet/__init__.py +9 -0
- megadetector/classification/efficientnet/model.py +415 -0
- megadetector/classification/efficientnet/utils.py +608 -0
- megadetector/classification/evaluate_model.py +520 -0
- megadetector/classification/identify_mislabeled_candidates.py +152 -0
- megadetector/classification/json_to_azcopy_list.py +63 -0
- megadetector/classification/json_validator.py +696 -0
- megadetector/classification/map_classification_categories.py +276 -0
- megadetector/classification/merge_classification_detection_output.py +509 -0
- megadetector/classification/prepare_classification_script.py +194 -0
- megadetector/classification/prepare_classification_script_mc.py +228 -0
- megadetector/classification/run_classifier.py +287 -0
- megadetector/classification/save_mislabeled.py +110 -0
- megadetector/classification/train_classifier.py +827 -0
- megadetector/classification/train_classifier_tf.py +725 -0
- megadetector/classification/train_utils.py +323 -0
- megadetector/data_management/__init__.py +0 -0
- megadetector/data_management/animl_to_md.py +161 -0
- megadetector/data_management/annotations/__init__.py +0 -0
- megadetector/data_management/annotations/annotation_constants.py +33 -0
- megadetector/data_management/camtrap_dp_to_coco.py +270 -0
- megadetector/data_management/cct_json_utils.py +566 -0
- megadetector/data_management/cct_to_md.py +184 -0
- megadetector/data_management/cct_to_wi.py +293 -0
- megadetector/data_management/coco_to_labelme.py +284 -0
- megadetector/data_management/coco_to_yolo.py +701 -0
- megadetector/data_management/databases/__init__.py +0 -0
- megadetector/data_management/databases/add_width_and_height_to_db.py +107 -0
- megadetector/data_management/databases/combine_coco_camera_traps_files.py +210 -0
- megadetector/data_management/databases/integrity_check_json_db.py +563 -0
- megadetector/data_management/databases/subset_json_db.py +195 -0
- megadetector/data_management/generate_crops_from_cct.py +200 -0
- megadetector/data_management/get_image_sizes.py +164 -0
- megadetector/data_management/labelme_to_coco.py +559 -0
- megadetector/data_management/labelme_to_yolo.py +349 -0
- megadetector/data_management/lila/__init__.py +0 -0
- megadetector/data_management/lila/create_lila_blank_set.py +556 -0
- megadetector/data_management/lila/create_lila_test_set.py +192 -0
- megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
- megadetector/data_management/lila/download_lila_subset.py +182 -0
- megadetector/data_management/lila/generate_lila_per_image_labels.py +777 -0
- megadetector/data_management/lila/get_lila_annotation_counts.py +174 -0
- megadetector/data_management/lila/get_lila_image_counts.py +112 -0
- megadetector/data_management/lila/lila_common.py +319 -0
- megadetector/data_management/lila/test_lila_metadata_urls.py +164 -0
- megadetector/data_management/mewc_to_md.py +344 -0
- megadetector/data_management/ocr_tools.py +873 -0
- megadetector/data_management/read_exif.py +964 -0
- megadetector/data_management/remap_coco_categories.py +195 -0
- megadetector/data_management/remove_exif.py +156 -0
- megadetector/data_management/rename_images.py +194 -0
- megadetector/data_management/resize_coco_dataset.py +665 -0
- megadetector/data_management/speciesnet_to_md.py +41 -0
- megadetector/data_management/wi_download_csv_to_coco.py +247 -0
- megadetector/data_management/yolo_output_to_md_output.py +594 -0
- megadetector/data_management/yolo_to_coco.py +984 -0
- megadetector/data_management/zamba_to_md.py +188 -0
- megadetector/detection/__init__.py +0 -0
- megadetector/detection/change_detection.py +840 -0
- megadetector/detection/process_video.py +479 -0
- megadetector/detection/pytorch_detector.py +1451 -0
- megadetector/detection/run_detector.py +1267 -0
- megadetector/detection/run_detector_batch.py +2172 -0
- megadetector/detection/run_inference_with_yolov5_val.py +1314 -0
- megadetector/detection/run_md_and_speciesnet.py +1604 -0
- megadetector/detection/run_tiled_inference.py +1044 -0
- megadetector/detection/tf_detector.py +209 -0
- megadetector/detection/video_utils.py +1379 -0
- megadetector/postprocessing/__init__.py +0 -0
- megadetector/postprocessing/add_max_conf.py +72 -0
- megadetector/postprocessing/categorize_detections_by_size.py +166 -0
- megadetector/postprocessing/classification_postprocessing.py +1943 -0
- megadetector/postprocessing/combine_batch_outputs.py +249 -0
- megadetector/postprocessing/compare_batch_results.py +2110 -0
- megadetector/postprocessing/convert_output_format.py +403 -0
- megadetector/postprocessing/create_crop_folder.py +629 -0
- megadetector/postprocessing/detector_calibration.py +570 -0
- megadetector/postprocessing/generate_csv_report.py +522 -0
- megadetector/postprocessing/load_api_results.py +223 -0
- megadetector/postprocessing/md_to_coco.py +428 -0
- megadetector/postprocessing/md_to_labelme.py +351 -0
- megadetector/postprocessing/md_to_wi.py +41 -0
- megadetector/postprocessing/merge_detections.py +392 -0
- megadetector/postprocessing/postprocess_batch_results.py +2140 -0
- megadetector/postprocessing/remap_detection_categories.py +226 -0
- megadetector/postprocessing/render_detection_confusion_matrix.py +677 -0
- megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +206 -0
- megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +82 -0
- megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1665 -0
- megadetector/postprocessing/separate_detections_into_folders.py +795 -0
- megadetector/postprocessing/subset_json_detector_output.py +964 -0
- megadetector/postprocessing/top_folders_to_bottom.py +238 -0
- megadetector/postprocessing/validate_batch_results.py +332 -0
- megadetector/taxonomy_mapping/__init__.py +0 -0
- megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
- megadetector/taxonomy_mapping/map_new_lila_datasets.py +211 -0
- megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +165 -0
- megadetector/taxonomy_mapping/preview_lila_taxonomy.py +543 -0
- megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
- megadetector/taxonomy_mapping/simple_image_download.py +231 -0
- megadetector/taxonomy_mapping/species_lookup.py +1008 -0
- megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
- megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
- megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
- megadetector/tests/__init__.py +0 -0
- megadetector/tests/test_nms_synthetic.py +335 -0
- megadetector/utils/__init__.py +0 -0
- megadetector/utils/ct_utils.py +1857 -0
- megadetector/utils/directory_listing.py +199 -0
- megadetector/utils/extract_frames_from_video.py +307 -0
- megadetector/utils/gpu_test.py +125 -0
- megadetector/utils/md_tests.py +2072 -0
- megadetector/utils/path_utils.py +2872 -0
- megadetector/utils/process_utils.py +172 -0
- megadetector/utils/split_locations_into_train_val.py +237 -0
- megadetector/utils/string_utils.py +234 -0
- megadetector/utils/url_utils.py +825 -0
- megadetector/utils/wi_platform_utils.py +968 -0
- megadetector/utils/wi_taxonomy_utils.py +1766 -0
- megadetector/utils/write_html_image_list.py +239 -0
- megadetector/visualization/__init__.py +0 -0
- megadetector/visualization/plot_utils.py +309 -0
- megadetector/visualization/render_images_with_thumbnails.py +243 -0
- megadetector/visualization/visualization_utils.py +1973 -0
- megadetector/visualization/visualize_db.py +630 -0
- megadetector/visualization/visualize_detector_output.py +498 -0
- megadetector/visualization/visualize_video_output.py +705 -0
- megadetector-10.0.15.dist-info/METADATA +115 -0
- megadetector-10.0.15.dist-info/RECORD +147 -0
- megadetector-10.0.15.dist-info/WHEEL +5 -0
- megadetector-10.0.15.dist-info/licenses/LICENSE +19 -0
- megadetector-10.0.15.dist-info/top_level.txt +1 -0
@@ -0,0 +1,725 @@
"""

train_classifier_tf.py

Train an EfficientNet classifier.

Currently the implementation of multi-label multi-class classification is
non-functional.

During training, start tensorboard from within the classification/ directory:
    tensorboard --logdir run --bind_all --samples_per_plugin scalars=0,images=0

"""

#%% Imports and constants

from __future__ import annotations

import os
import json
import uuid
import argparse

import tqdm
import numpy as np
import sklearn.metrics
import tensorflow as tf

from collections import defaultdict
from collections.abc import Callable, Mapping, MutableMapping, Sequence
from datetime import datetime
from typing import Any, Optional
from tensorboard.plugins.hparams import api as hp

from megadetector.classification.train_utils import (
    HeapItem, recall_from_confusion_matrix, add_to_heap, fig_to_img,
    imgs_with_confidences, load_dataset_csv, prefix_all_keys)
from megadetector.visualization import plot_utils

AUTOTUNE = tf.data.experimental.AUTOTUNE

# match pytorch EfficientNet model names
EFFICIENTNET_MODELS: Mapping[str, Mapping[str, Any]] = {
    'efficientnet-b0': dict(cls='EfficientNetB0', img_size=224, dropout=0.2),
    'efficientnet-b1': dict(cls='EfficientNetB1', img_size=240, dropout=0.2),
    'efficientnet-b2': dict(cls='EfficientNetB2', img_size=260, dropout=0.3),
    'efficientnet-b3': dict(cls='EfficientNetB3', img_size=300, dropout=0.3),
    'efficientnet-b4': dict(cls='EfficientNetB4', img_size=380, dropout=0.4),
    'efficientnet-b5': dict(cls='EfficientNetB5', img_size=456, dropout=0.4),
    'efficientnet-b6': dict(cls='EfficientNetB6', img_size=528, dropout=0.5),
    'efficientnet-b7': dict(cls='EfficientNetB7', img_size=600, dropout=0.5)
}

#%% Example usage

"""
python train_classifier_tf.py run_idfg /ssd/crops_sq \
    -m "efficientnet-b0" --pretrained --finetune --label-weighted \
    --epochs 50 --batch-size 512 --lr 1e-4 \
    --seed 123 \
    --logdir run_idfg
"""

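# (Illustrative) equivalent programmatic call to main(), using the same
# hypothetical paths and hyperparameters as the command above; weight_decay
# and finetune are shown at their argparse defaults:
#
#   main(dataset_dir='run_idfg', cropped_images_dir='/ssd/crops_sq',
#        multilabel=False, model_name='efficientnet-b0', pretrained=True,
#        finetune=0, label_weighted=True, weight_by_detection_conf=False,
#        epochs=50, batch_size=512, lr=1e-4, weight_decay=1e-5,
#        seed=123, logdir='run_idfg')
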
#%% Support functions

def create_dataset(
        img_files: Sequence[str],
        labels: Sequence[Any],
        sample_weights: Optional[Sequence[float]] = None,
        img_base_dir: str = '',
        transform: Optional[Callable[[tf.Tensor], Any]] = None,
        target_transform: Optional[Callable[[Any], Any]] = None,
        cache: bool | str = False
        ) -> tf.data.Dataset:
    """
    Create a tf.data.Dataset.

    The dataset returns elements (img, label, img_file, sample_weight) if
    sample_weights is not None, or (img, label, img_file) if
    sample_weights=None.
        img: tf.Tensor, shape [H, W, 3], type uint8
        label: tf.Tensor
        img_file: tf.Tensor, scalar, type str
        sample_weight: tf.Tensor, scalar, type float32

    Args:
        img_files: list of str, relative paths from img_base_dir
        labels: list of int if multilabel=False
        sample_weights: optional list of float
        img_base_dir: str, base directory for images
        transform: optional transform to apply to a single uint8 JPEG image
        target_transform: optional transform to apply to a single label
        cache: bool or str, cache images in memory if True, cache images to
            a file on disk if a str

    Returns: tf.data.Dataset
    """

    # images dataset
    img_ds = tf.data.Dataset.from_tensor_slices(img_files)
    img_ds = img_ds.map(lambda p: tf.io.read_file(img_base_dir + os.sep + p),
                        num_parallel_calls=AUTOTUNE)

    # for smaller disk / memory usage, we cache the raw JPEG bytes instead
    # of the decoded Tensor
    if isinstance(cache, str):
        img_ds = img_ds.cache(cache)
    elif cache:
        img_ds = img_ds.cache()

    # convert JPEG bytes to a 3D uint8 Tensor
    # keras EfficientNet already includes normalization from [0, 255] to [0, 1],
    # so we don't need to do that here
    img_ds = img_ds.map(lambda img: tf.io.decode_jpeg(img, channels=3))

    if transform:
        img_ds = img_ds.map(transform, num_parallel_calls=AUTOTUNE)

    # labels dataset
    labels_ds = tf.data.Dataset.from_tensor_slices(labels)
    if target_transform:
        labels_ds = labels_ds.map(target_transform, num_parallel_calls=AUTOTUNE)

    # img_files dataset
    img_files_ds = tf.data.Dataset.from_tensor_slices(img_files)

    if sample_weights is None:
        return tf.data.Dataset.zip((img_ds, labels_ds, img_files_ds))

    # weights dataset
    weights_ds = tf.data.Dataset.from_tensor_slices(sample_weights)
    return tf.data.Dataset.zip((img_ds, labels_ds, img_files_ds, weights_ds))

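# (Illustrative) minimal create_dataset() sketch, with hypothetical crop paths
# and integer labels; with sample_weights=None the dataset yields
# (img, label, img_file) triples:
#
#   ds = create_dataset(img_files=['loc1/crop_001.jpg', 'loc1/crop_002.jpg'],
#                       labels=[0, 3], img_base_dir='/ssd/crops_sq')
#   for img, label, img_file in ds.take(2):
#       print(img.shape, label.numpy(), img_file.numpy())
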
def create_dataloaders(
        dataset_csv_path: str,
        label_index_json_path: str,
        splits_json_path: str,
        cropped_images_dir: str,
        img_size: int,
        multilabel: bool,
        label_weighted: bool,
        weight_by_detection_conf: bool | str,
        batch_size: int,
        augment_train: bool,
        cache_splits: Sequence[str]
        ) -> tuple[dict[str, tf.data.Dataset], list[str]]:
    """
    Args:
        dataset_csv_path: str, path to CSV file with columns
            ['dataset', 'location', 'label'], where label is a comma-delimited
            list of labels
        splits_json_path: str, path to JSON file
        augment_train: bool, whether to shuffle/augment the training set
        cache_splits: list of str, splits to cache
            training set is cached at /mnt/tempds/random_file_name
            validation and test sets are cached in memory

    Returns:
        datasets: dict, maps split to DataLoader
        label_names: list of str, label names in order of label id
    """

    df, label_names, split_to_locs = load_dataset_csv(
        dataset_csv_path, label_index_json_path, splits_json_path,
        multilabel=multilabel, label_weighted=label_weighted,
        weight_by_detection_conf=weight_by_detection_conf)

    # define the transforms

    # efficientnet data preprocessing:
    # - train:
    #   1) random crop: aspect_ratio_range=(0.75, 1.33), area_range=(0.08, 1.0)
    #   2) bicubic resize to img_size
    #   3) random horizontal flip
    # - test:
    #   1) center crop
    #   2) bicubic resize to img_size

    @tf.function
    def train_transform(img: tf.Tensor) -> tf.Tensor:
        """Returns: tf.Tensor, shape [img_size, img_size, C], type float32"""
        img = tf.image.resize_with_pad(img, img_size, img_size,
                                       method=tf.image.ResizeMethod.BICUBIC)
        img = tf.image.random_flip_left_right(img)
        img = tf.image.random_brightness(img, max_delta=0.25)
        img = tf.image.random_contrast(img, lower=0.75, upper=1.25)
        img = tf.image.random_saturation(img, lower=0.75, upper=1.25)
        return img

    @tf.function
    def test_transform(img: tf.Tensor) -> tf.Tensor:
        """Returns: tf.Tensor, shape [img_size, img_size, C], type float32"""
        img = tf.image.resize_with_pad(img, img_size, img_size,
                                       method=tf.image.ResizeMethod.BICUBIC)
        return img

    dataloaders = {}
    for split, locs in split_to_locs.items():
        is_train = (split == 'train') and augment_train
        split_df = df[df['dataset_location'].isin(locs)]

        weights = None
        if label_weighted or weight_by_detection_conf:
            # weights sums to:
            # - if weight_by_detection_conf: (# images in split - conf delta)
            # - otherwise: (# images in split)
            weights = split_df['weights'].tolist()
            if not weight_by_detection_conf:
                assert np.isclose(sum(weights), len(split_df))

        cache: bool | str = (split in cache_splits)
        if split == 'train' and 'train' in cache_splits:
            unique_filename = str(uuid.uuid4())
            os.makedirs('/mnt/tempds/', exist_ok=True)
            cache = f'/mnt/tempds/{unique_filename}'

        ds = create_dataset(
            img_files=split_df['path'].tolist(),
            labels=split_df['label_index'].tolist(),
            sample_weights=weights,
            img_base_dir=cropped_images_dir,
            transform=train_transform if is_train else test_transform,
            target_transform=None,
            cache=cache)
        if is_train:
            ds = ds.shuffle(1000, reshuffle_each_iteration=True)
        ds = ds.batch(batch_size).prefetch(buffer_size=AUTOTUNE)
        dataloaders[split] = ds

    return dataloaders, label_names

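# (Illustrative) create_dataloaders() sketch, mirroring the example usage
# above; all paths and hyperparameters here are hypothetical:
#
#   loaders, label_names = create_dataloaders(
#       dataset_csv_path='run_idfg/classification_ds.csv',
#       label_index_json_path='run_idfg/label_index.json',
#       splits_json_path='run_idfg/splits.json',
#       cropped_images_dir='/ssd/crops_sq',
#       img_size=224, multilabel=False, label_weighted=True,
#       weight_by_detection_conf=False, batch_size=512,
#       augment_train=True, cache_splits=('val', 'test'))
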
def build_model(model_name: str, num_classes: int, img_size: int,
                pretrained: bool, finetune: bool) -> tf.keras.Model:
    """
    Creates a model with an EfficientNet base.
    """

    class_name = EFFICIENTNET_MODELS[model_name]['cls']
    dropout = EFFICIENTNET_MODELS[model_name]['dropout']

    model_class = tf.keras.applications.__dict__[class_name]
    weights = 'imagenet' if pretrained else None
    inputs = tf.keras.layers.Input(shape=(img_size, img_size, 3))
    base_model = model_class(
        input_tensor=inputs, weights=weights, include_top=False, pooling='avg')

    if finetune:
        # freeze the base model's weights, including BatchNorm statistics
        # https://www.tensorflow.org/guide/keras/transfer_learning#fine-tuning
        base_model.trainable = False

    # rebuild output
    x = tf.keras.layers.Dropout(dropout, name='top_dropout')(base_model.output)
    outputs = tf.keras.layers.Dense(
        num_classes,
        kernel_initializer=tf.keras.initializers.VarianceScaling(
            scale=1. / 3., mode='fan_out', distribution='uniform'),
        name='logits')(x)
    model = tf.keras.Model(inputs, outputs, name='complete_model')
    model.base_model = base_model  # cache this so that we can turn off finetune
    return model

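# (Illustrative) build_model() sketch: an ImageNet-pretrained EfficientNet-B0
# with a frozen base and a fresh classification head; the 10-class count is
# hypothetical:
#
#   model = build_model('efficientnet-b0', num_classes=10, img_size=224,
#                       pretrained=True, finetune=True)
#   model.summary()
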
def log_images_with_confidence(
        heap_dict: Mapping[int, list[HeapItem]],
        label_names: Sequence[str],
        epoch: int,
        tag: str) -> None:
    """
    Args:
        heap_dict: dict, maps label_id to list of HeapItem, where each HeapItem
            data is a list [img, target, top3_conf, top3_preds, img_file],
            and img is a tf.Tensor of shape [H, W, 3]
        label_names: list of str, label names in order of label id
        epoch: int
        tag: str
    """

    for label_id, heap in heap_dict.items():
        label_name = label_names[label_id]

        sorted_heap = sorted(heap, reverse=True)  # sort largest to smallest
        imgs_list = [item.data for item in sorted_heap]
        fig, img_files = imgs_with_confidences(imgs_list, label_names)

        # tf.summary.image requires input of shape [N, H, W, C]
        fig_img = tf.convert_to_tensor(fig_to_img(fig)[np.newaxis, ...])
        tf.summary.image(f'{label_name}/{tag}', fig_img, step=epoch)
        tf.summary.text(f'{label_name}/{tag}_files', '\n\n'.join(img_files),
                        step=epoch)

def track_extreme_examples(tp_heaps: dict[int, list[HeapItem]],
                           fp_heaps: dict[int, list[HeapItem]],
                           fn_heaps: dict[int, list[HeapItem]],
                           inputs: tf.Tensor,
                           labels: tf.Tensor,
                           img_files: tf.Tensor,
                           logits: tf.Tensor) -> None:
    """
    Updates the 5 most extreme true-positive (tp), false-positive (fp), and
    false-negative (fn) examples with examples from this batch.

    Each HeapItem's data attribute is a tuple with:
    - img: np.ndarray, shape [H, W, 3], type uint8
    - label: int
    - top3_conf: list of float
    - top3_preds: list of int
    - img_file: str

    Args:
        *_heaps: dict, maps label_id (int) to heap of HeapItems
        inputs: tf.Tensor, shape [batch_size, H, W, 3], type float32
        labels: tf.Tensor, shape [batch_size]
        img_files: tf.Tensor, shape [batch_size], type tf.string
        logits: tf.Tensor, shape [batch_size, num_classes]
    """

    labels = labels.numpy().tolist()
    inputs = inputs.numpy().astype(np.uint8)
    img_files = img_files.numpy().astype(str).tolist()
    batch_probs = tf.nn.softmax(logits, axis=1)
    iterable = zip(labels, inputs, img_files, batch_probs)
    for label, img, img_file, confs in iterable:
        label_conf = confs[label].numpy().item()

        top3_conf, top3_preds = tf.math.top_k(confs, k=3, sorted=True)
        top3_conf = top3_conf.numpy().tolist()
        top3_preds = top3_preds.numpy().tolist()

        data = (img, label, top3_conf, top3_preds, img_file)
        if top3_preds[0] == label:  # true positive
            item = HeapItem(priority=label_conf - top3_conf[1], data=data)
            add_to_heap(tp_heaps[label], item, k=5)
        else:
            # false positive for top3_preds[0]
            # false negative for label
            item = HeapItem(priority=top3_conf[0] - label_conf, data=data)
            add_to_heap(fp_heaps[top3_preds[0]], item, k=5)
            add_to_heap(fn_heaps[label], item, k=5)

def run_epoch(model: tf.keras.Model,
              loader: tf.data.Dataset,
              weighted: bool,
              top: Sequence[int] = (1, 3),
              loss_fn: Optional[tf.keras.losses.Loss] = None,
              weight_decay: float = 0,
              finetune: bool = False,
              optimizer: Optional[tf.keras.optimizers.Optimizer] = None,
              return_extreme_images: bool = False
              ) -> tuple[
                  dict[str, float],
                  dict[str, dict[int, list[HeapItem]]],
                  np.ndarray
              ]:
    """
    Runs for 1 epoch.

    Args:
        model: tf.keras.Model
        loader: tf.data.Dataset
        weighted: bool, whether to use sample weights in calculating loss and
            accuracy
        top: tuple of int, list of values of k for calculating top-K accuracy
        loss_fn: optional loss function, calculates the mean loss over a batch
        weight_decay: float, L2-regularization constant
        finetune: bool, if true sets model's dropout and BN layers to eval mode
        optimizer: optional optimizer

    Returns:
        metrics: dict, metrics from epoch, contains keys:
            'loss': float, mean per-example loss over entire epoch,
                only included if loss_fn is not None
            'acc_top{k}': float, accuracy@k over the entire epoch
        heaps: dict, keys are ['tp', 'fp', 'fn'], values are heap_dicts,
            each heap_dict maps label_id (int) to a heap of <= 5 HeapItems with
            data attribute (img, target, top3_conf, top3_preds, img_file)
            - 'tp': priority is the difference between target confidence and
                2nd highest confidence
            - 'fp': priority is the difference between highest confidence and
                target confidence
            - 'fn': same as 'fp'
        confusion_matrix: np.ndarray, shape [num_classes, num_classes],
            C[i, j] = # of samples with true label i, predicted as label j
    """

    # if evaluating or finetuning, set dropout & BN layers to eval mode
    is_train = False
    train_dropout_and_bn = False

    if optimizer is not None:
        assert loss_fn is not None
        is_train = True

        if not finetune:
            train_dropout_and_bn = True
            reg_vars = [
                v for v in model.trainable_variables if 'kernel' in v.name]

    if loss_fn is not None:
        losses = tf.keras.metrics.Mean()
    accuracies_topk = {
        k: tf.keras.metrics.SparseTopKCategoricalAccuracy(k) for k in top
    }

    # for each label, track 5 most-confident and least-confident examples
    tp_heaps: dict[int, list[HeapItem]] = defaultdict(list)
    fp_heaps: dict[int, list[HeapItem]] = defaultdict(list)
    fn_heaps: dict[int, list[HeapItem]] = defaultdict(list)

    all_labels = []
    all_preds = []

    tqdm_loader = tqdm.tqdm(loader)
    for batch in tqdm_loader:
        if weighted:
            inputs, labels, img_files, weights = batch
        else:
            # even if batch contains sample weights, don't use them
            inputs, labels, img_files = batch[0:3]
            weights = None

        all_labels.append(labels.numpy())
        desc = []
        with tf.GradientTape(watch_accessed_variables=is_train) as tape:
            outputs = model(inputs, training=train_dropout_and_bn)
            if loss_fn is not None:
                loss = loss_fn(labels, outputs)
                if weights is not None:
                    loss *= weights
                # we do not track L2-regularization loss in the loss metric
                losses.update_state(loss, sample_weight=weights)
                desc.append(f'Loss {losses.result().numpy():.4f}')

            if optimizer is not None:
                loss = tf.math.reduce_mean(loss)
                if not finetune:  # only regularize layers before the final FC
                    loss += weight_decay * tf.add_n(
                        [tf.nn.l2_loss(v) for v in reg_vars])

        all_preds.append(tf.math.argmax(outputs, axis=1).numpy())

        if optimizer is not None:
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        for k, acc in accuracies_topk.items():
            acc.update_state(labels, outputs, sample_weight=weights)
            desc.append(f'Acc@{k} {acc.result().numpy() * 100:.3f}')
        tqdm_loader.set_description(' '.join(desc))

        if return_extreme_images:
            track_extreme_examples(tp_heaps, fp_heaps, fn_heaps, inputs,
                                   labels, img_files, outputs)

    confusion_matrix = sklearn.metrics.confusion_matrix(
        y_true=np.concatenate(all_labels), y_pred=np.concatenate(all_preds))

    metrics = {}
    if loss_fn is not None:
        metrics['loss'] = losses.result().numpy().item()
    for k, acc in accuracies_topk.items():
        metrics[f'acc_top{k}'] = acc.result().numpy().item() * 100
    heaps = {'tp': tp_heaps, 'fp': fp_heaps, 'fn': fn_heaps}
    return metrics, heaps, confusion_matrix

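# (Illustrative) eval-only run_epoch() sketch: with no optimizer there are no
# gradient updates, matching the validation/test calls in main() below:
#
#   val_metrics, val_heaps, val_cm = run_epoch(
#       model, loader=loaders['val'], weighted=False, loss_fn=loss_fn)
#   print(val_metrics['acc_top1'], val_cm.shape)
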
def log_run(split: str, epoch: int, writer: tf.summary.SummaryWriter,
            label_names: Sequence[str], metrics: MutableMapping[str, float],
            heaps: Mapping[str, Mapping[int, list[HeapItem]]], cm: np.ndarray
            ) -> None:
    """
    Logs the outputs (metrics, confusion matrix, tp/fp/fn images) from a
    single epoch run to Tensorboard.

    Args:
        metrics: dict, keys already prefixed with {split}/
    """

    per_class_recall = recall_from_confusion_matrix(cm, label_names)
    metrics.update(prefix_all_keys(per_class_recall, f'{split}/label_recall/'))

    # log metrics
    for metric, value in metrics.items():
        tf.summary.scalar(metric, value, epoch)

    # log confusion matrix
    cm_fig = plot_utils.plot_confusion_matrix(cm, classes=label_names,
                                              normalize=True)
    cm_fig_img = tf.convert_to_tensor(fig_to_img(cm_fig)[np.newaxis, ...])
    tf.summary.image(f'confusion_matrix/{split}', cm_fig_img, step=epoch)

    # log tp/fp/fn images
    for heap_type, heap_dict in heaps.items():
        log_images_with_confidence(heap_dict, label_names, epoch=epoch,
                                   tag=f'{split}/{heap_type}')
    writer.flush()

#%% Main function

def main(dataset_dir: str,
         cropped_images_dir: str,
         multilabel: bool,
         model_name: str,
         pretrained: bool,
         finetune: int,
         label_weighted: bool,
         weight_by_detection_conf: bool | str,
         epochs: int,
         batch_size: int,
         lr: float,
         weight_decay: float,
         seed: Optional[int] = None,
         logdir: str = '',
         cache_splits: Sequence[str] = ()) -> None:

    # input validation
    assert os.path.exists(dataset_dir)
    assert os.path.exists(cropped_images_dir)
    if isinstance(weight_by_detection_conf, str):
        assert os.path.exists(weight_by_detection_conf)

    # set seed
    seed = np.random.randint(10_000) if seed is None else seed
    np.random.seed(seed)
    tf.random.set_seed(seed)

    # create logdir and save params
    params = dict(locals())  # make a copy
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')  # '20200722_110816'
    logdir = os.path.join(logdir, timestamp)
    os.makedirs(logdir, exist_ok=True)
    print('Created logdir:', logdir)
    with open(os.path.join(logdir, 'params.json'), 'w') as f:
        json.dump(params, f, indent=1)

    gpus = tf.config.experimental.list_physical_devices('GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

    img_size = EFFICIENTNET_MODELS[model_name]['img_size']

    # create dataloaders and log the index_to_label mapping
    loaders, label_names = create_dataloaders(
        dataset_csv_path=os.path.join(dataset_dir, 'classification_ds.csv'),
        label_index_json_path=os.path.join(dataset_dir, 'label_index.json'),
        splits_json_path=os.path.join(dataset_dir, 'splits.json'),
        cropped_images_dir=cropped_images_dir,
        img_size=img_size,
        multilabel=multilabel,
        label_weighted=label_weighted,
        weight_by_detection_conf=weight_by_detection_conf,
        batch_size=batch_size,
        augment_train=True,
        cache_splits=cache_splits)

    writer = tf.summary.create_file_writer(logdir)
    writer.set_as_default()

    model = build_model(
        model_name, num_classes=len(label_names), img_size=img_size,
        pretrained=pretrained, finetune=finetune > 0)

    # define loss function and optimizer
    loss_fn: tf.keras.losses.Loss
    if multilabel:
        loss_fn = tf.keras.losses.BinaryCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
    else:
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

    # using EfficientNet training defaults
    # - batch norm momentum: 0.99
    # - optimizer: RMSProp, decay 0.9 and momentum 0.9
    # - epochs: 350
    # - learning rate: 0.256, decays by 0.97 every 2.4 epochs
    # - weight decay: 1e-5
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        lr, decay_steps=1, decay_rate=0.97, staircase=True)
    optimizer = tf.keras.optimizers.RMSprop(
        learning_rate=lr, rho=0.9, momentum=0.9)

    best_epoch_metrics: dict[str, float] = {}
    for epoch in range(epochs):
        print(f'Epoch: {epoch}')
        optimizer.learning_rate = lr_schedule(epoch)
        tf.summary.scalar('lr', optimizer.learning_rate, epoch)

        if epoch > 0 and finetune == epoch:
            print('Turning off fine-tune!')
            model.base_model.trainable = True

        print('- train:')

        train_metrics, train_heaps, train_cm = run_epoch(
            model, loader=loaders['train'], weighted=label_weighted,
            loss_fn=loss_fn, weight_decay=weight_decay, optimizer=optimizer,
            finetune=finetune > epoch, return_extreme_images=True)
        train_metrics = prefix_all_keys(train_metrics, prefix='train/')
        log_run('train', epoch, writer, label_names,
                metrics=train_metrics, heaps=train_heaps, cm=train_cm)

        print('- val:')
        val_metrics, val_heaps, val_cm = run_epoch(
            model, loader=loaders['val'], weighted=label_weighted,
            loss_fn=loss_fn, return_extreme_images=True)
        val_metrics = prefix_all_keys(val_metrics, prefix='val/')
        log_run('val', epoch, writer, label_names,
                metrics=val_metrics, heaps=val_heaps, cm=val_cm)

        if val_metrics['val/acc_top1'] > best_epoch_metrics.get('val/acc_top1', 0):  # pylint: disable=line-too-long
            filename = os.path.join(logdir, f'ckpt_{epoch}.h5')
            print(f'New best model! Saving checkpoint to {filename}')
            model.save(filename)
            best_epoch_metrics.update(train_metrics)
            best_epoch_metrics.update(val_metrics)
            best_epoch_metrics['epoch'] = epoch

            print('- test:')
            test_metrics, test_heaps, test_cm = run_epoch(
                model, loader=loaders['test'], weighted=label_weighted,
                loss_fn=loss_fn, return_extreme_images=True)
            test_metrics = prefix_all_keys(test_metrics, prefix='test/')
            log_run('test', epoch, writer, label_names,
                    metrics=test_metrics, heaps=test_heaps, cm=test_cm)

        # stop training after 8 epochs without improvement
        if epoch >= best_epoch_metrics['epoch'] + 8:
            break

    hparams_dict = {
        'model_name': model_name,
        'multilabel': multilabel,
        'finetune': finetune,
        'batch_size': batch_size,
        'epochs': epochs
    }
    hp.hparams(hparams_dict)
    writer.close()

#%% Command-line driver

def _parse_args() -> argparse.Namespace:
    """Parses arguments."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Trains classifier.')
    parser.add_argument(
        'dataset_dir',
        help='path to directory containing: 1) classification dataset CSV, '
             '2) label index JSON, 3) splits JSON')
    parser.add_argument(
        'cropped_images_dir',
        help='path to local directory where image crops are saved')
    parser.add_argument(
        '--multilabel', action='store_true',
        help='for multi-label, multi-class classification')
    parser.add_argument(
        '-m', '--model-name', default='efficientnet-b0',
        choices=list(EFFICIENTNET_MODELS.keys()),
        help='which EfficientNet model')
    parser.add_argument(
        '--pretrained', action='store_true',
        help='start with pretrained model')
    parser.add_argument(
        '--finetune', type=int, default=0,
        help='only fine tune the final fully-connected layer for the first '
             '<finetune> epochs')
    parser.add_argument(
        '--label-weighted', action='store_true',
        help='weight training samples to balance labels')
    parser.add_argument(
        '--weight-by-detection-conf', nargs='?', const=True, default=False,
        help='weight training examples by detection confidence. '
             'Optionally takes a .npz file for isotonic calibration.')
    parser.add_argument(
        '--epochs', type=int, default=0,
        help='number of epochs for training, 0 for eval-only')
    parser.add_argument(
        '--batch-size', type=int, default=256,
        help='batch size for both training and eval')
    parser.add_argument(
        '--lr', type=float, default=None,
        help='initial learning rate, defaults to (0.016 * batch_size / 256)')
    parser.add_argument(
        '--weight-decay', type=float, default=1e-5,
        help='weight decay')
    parser.add_argument(
        '--seed', type=int,
        help='random seed')
    parser.add_argument(
        '--logdir', default='.',
        help='directory where TensorBoard logs and a params file are saved')
    parser.add_argument(
        '--cache', nargs='*', choices=['train', 'val', 'test'], default=(),
        help='which splits of the dataset to cache')
    return parser.parse_args()


if __name__ == '__main__':
    args = _parse_args()
    if args.lr is None:
        args.lr = 0.016 * args.batch_size / 256  # based on TF models repo
    main(dataset_dir=args.dataset_dir,
         cropped_images_dir=args.cropped_images_dir,
         multilabel=args.multilabel,
         model_name=args.model_name,
         pretrained=args.pretrained,
         finetune=args.finetune,
         label_weighted=args.label_weighted,
         weight_by_detection_conf=args.weight_by_detection_conf,
         epochs=args.epochs,
         batch_size=args.batch_size,
         lr=args.lr,
         weight_decay=args.weight_decay,
         seed=args.seed,
         logdir=args.logdir,
         cache_splits=args.cache)