megadetector 10.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of megadetector might be problematic. Click here for more details.

Files changed (147) hide show
  1. megadetector/__init__.py +0 -0
  2. megadetector/api/__init__.py +0 -0
  3. megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
  4. megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
  5. megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
  6. megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +125 -0
  7. megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
  8. megadetector/classification/__init__.py +0 -0
  9. megadetector/classification/aggregate_classifier_probs.py +108 -0
  10. megadetector/classification/analyze_failed_images.py +227 -0
  11. megadetector/classification/cache_batchapi_outputs.py +198 -0
  12. megadetector/classification/create_classification_dataset.py +626 -0
  13. megadetector/classification/crop_detections.py +516 -0
  14. megadetector/classification/csv_to_json.py +226 -0
  15. megadetector/classification/detect_and_crop.py +853 -0
  16. megadetector/classification/efficientnet/__init__.py +9 -0
  17. megadetector/classification/efficientnet/model.py +415 -0
  18. megadetector/classification/efficientnet/utils.py +608 -0
  19. megadetector/classification/evaluate_model.py +520 -0
  20. megadetector/classification/identify_mislabeled_candidates.py +152 -0
  21. megadetector/classification/json_to_azcopy_list.py +63 -0
  22. megadetector/classification/json_validator.py +696 -0
  23. megadetector/classification/map_classification_categories.py +276 -0
  24. megadetector/classification/merge_classification_detection_output.py +509 -0
  25. megadetector/classification/prepare_classification_script.py +194 -0
  26. megadetector/classification/prepare_classification_script_mc.py +228 -0
  27. megadetector/classification/run_classifier.py +287 -0
  28. megadetector/classification/save_mislabeled.py +110 -0
  29. megadetector/classification/train_classifier.py +827 -0
  30. megadetector/classification/train_classifier_tf.py +725 -0
  31. megadetector/classification/train_utils.py +323 -0
  32. megadetector/data_management/__init__.py +0 -0
  33. megadetector/data_management/animl_to_md.py +161 -0
  34. megadetector/data_management/annotations/__init__.py +0 -0
  35. megadetector/data_management/annotations/annotation_constants.py +33 -0
  36. megadetector/data_management/camtrap_dp_to_coco.py +270 -0
  37. megadetector/data_management/cct_json_utils.py +566 -0
  38. megadetector/data_management/cct_to_md.py +184 -0
  39. megadetector/data_management/cct_to_wi.py +293 -0
  40. megadetector/data_management/coco_to_labelme.py +284 -0
  41. megadetector/data_management/coco_to_yolo.py +702 -0
  42. megadetector/data_management/databases/__init__.py +0 -0
  43. megadetector/data_management/databases/add_width_and_height_to_db.py +107 -0
  44. megadetector/data_management/databases/combine_coco_camera_traps_files.py +210 -0
  45. megadetector/data_management/databases/integrity_check_json_db.py +528 -0
  46. megadetector/data_management/databases/subset_json_db.py +195 -0
  47. megadetector/data_management/generate_crops_from_cct.py +200 -0
  48. megadetector/data_management/get_image_sizes.py +164 -0
  49. megadetector/data_management/labelme_to_coco.py +559 -0
  50. megadetector/data_management/labelme_to_yolo.py +349 -0
  51. megadetector/data_management/lila/__init__.py +0 -0
  52. megadetector/data_management/lila/create_lila_blank_set.py +556 -0
  53. megadetector/data_management/lila/create_lila_test_set.py +187 -0
  54. megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
  55. megadetector/data_management/lila/download_lila_subset.py +182 -0
  56. megadetector/data_management/lila/generate_lila_per_image_labels.py +777 -0
  57. megadetector/data_management/lila/get_lila_annotation_counts.py +174 -0
  58. megadetector/data_management/lila/get_lila_image_counts.py +112 -0
  59. megadetector/data_management/lila/lila_common.py +319 -0
  60. megadetector/data_management/lila/test_lila_metadata_urls.py +164 -0
  61. megadetector/data_management/mewc_to_md.py +344 -0
  62. megadetector/data_management/ocr_tools.py +873 -0
  63. megadetector/data_management/read_exif.py +964 -0
  64. megadetector/data_management/remap_coco_categories.py +195 -0
  65. megadetector/data_management/remove_exif.py +156 -0
  66. megadetector/data_management/rename_images.py +194 -0
  67. megadetector/data_management/resize_coco_dataset.py +663 -0
  68. megadetector/data_management/speciesnet_to_md.py +41 -0
  69. megadetector/data_management/wi_download_csv_to_coco.py +247 -0
  70. megadetector/data_management/yolo_output_to_md_output.py +594 -0
  71. megadetector/data_management/yolo_to_coco.py +876 -0
  72. megadetector/data_management/zamba_to_md.py +188 -0
  73. megadetector/detection/__init__.py +0 -0
  74. megadetector/detection/change_detection.py +840 -0
  75. megadetector/detection/process_video.py +479 -0
  76. megadetector/detection/pytorch_detector.py +1451 -0
  77. megadetector/detection/run_detector.py +1267 -0
  78. megadetector/detection/run_detector_batch.py +2159 -0
  79. megadetector/detection/run_inference_with_yolov5_val.py +1314 -0
  80. megadetector/detection/run_md_and_speciesnet.py +1494 -0
  81. megadetector/detection/run_tiled_inference.py +1038 -0
  82. megadetector/detection/tf_detector.py +209 -0
  83. megadetector/detection/video_utils.py +1379 -0
  84. megadetector/postprocessing/__init__.py +0 -0
  85. megadetector/postprocessing/add_max_conf.py +72 -0
  86. megadetector/postprocessing/categorize_detections_by_size.py +166 -0
  87. megadetector/postprocessing/classification_postprocessing.py +1752 -0
  88. megadetector/postprocessing/combine_batch_outputs.py +249 -0
  89. megadetector/postprocessing/compare_batch_results.py +2110 -0
  90. megadetector/postprocessing/convert_output_format.py +403 -0
  91. megadetector/postprocessing/create_crop_folder.py +629 -0
  92. megadetector/postprocessing/detector_calibration.py +570 -0
  93. megadetector/postprocessing/generate_csv_report.py +522 -0
  94. megadetector/postprocessing/load_api_results.py +223 -0
  95. megadetector/postprocessing/md_to_coco.py +428 -0
  96. megadetector/postprocessing/md_to_labelme.py +351 -0
  97. megadetector/postprocessing/md_to_wi.py +41 -0
  98. megadetector/postprocessing/merge_detections.py +392 -0
  99. megadetector/postprocessing/postprocess_batch_results.py +2077 -0
  100. megadetector/postprocessing/remap_detection_categories.py +226 -0
  101. megadetector/postprocessing/render_detection_confusion_matrix.py +677 -0
  102. megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +206 -0
  103. megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +82 -0
  104. megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1665 -0
  105. megadetector/postprocessing/separate_detections_into_folders.py +795 -0
  106. megadetector/postprocessing/subset_json_detector_output.py +964 -0
  107. megadetector/postprocessing/top_folders_to_bottom.py +238 -0
  108. megadetector/postprocessing/validate_batch_results.py +332 -0
  109. megadetector/taxonomy_mapping/__init__.py +0 -0
  110. megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
  111. megadetector/taxonomy_mapping/map_new_lila_datasets.py +213 -0
  112. megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +165 -0
  113. megadetector/taxonomy_mapping/preview_lila_taxonomy.py +543 -0
  114. megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
  115. megadetector/taxonomy_mapping/simple_image_download.py +224 -0
  116. megadetector/taxonomy_mapping/species_lookup.py +1008 -0
  117. megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
  118. megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
  119. megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
  120. megadetector/tests/__init__.py +0 -0
  121. megadetector/tests/test_nms_synthetic.py +335 -0
  122. megadetector/utils/__init__.py +0 -0
  123. megadetector/utils/ct_utils.py +1857 -0
  124. megadetector/utils/directory_listing.py +199 -0
  125. megadetector/utils/extract_frames_from_video.py +307 -0
  126. megadetector/utils/gpu_test.py +125 -0
  127. megadetector/utils/md_tests.py +2072 -0
  128. megadetector/utils/path_utils.py +2832 -0
  129. megadetector/utils/process_utils.py +172 -0
  130. megadetector/utils/split_locations_into_train_val.py +237 -0
  131. megadetector/utils/string_utils.py +234 -0
  132. megadetector/utils/url_utils.py +825 -0
  133. megadetector/utils/wi_platform_utils.py +968 -0
  134. megadetector/utils/wi_taxonomy_utils.py +1759 -0
  135. megadetector/utils/write_html_image_list.py +239 -0
  136. megadetector/visualization/__init__.py +0 -0
  137. megadetector/visualization/plot_utils.py +309 -0
  138. megadetector/visualization/render_images_with_thumbnails.py +243 -0
  139. megadetector/visualization/visualization_utils.py +1940 -0
  140. megadetector/visualization/visualize_db.py +630 -0
  141. megadetector/visualization/visualize_detector_output.py +479 -0
  142. megadetector/visualization/visualize_video_output.py +705 -0
  143. megadetector-10.0.13.dist-info/METADATA +134 -0
  144. megadetector-10.0.13.dist-info/RECORD +147 -0
  145. megadetector-10.0.13.dist-info/WHEEL +5 -0
  146. megadetector-10.0.13.dist-info/licenses/LICENSE +19 -0
  147. megadetector-10.0.13.dist-info/top_level.txt +1 -0
@@ -0,0 +1,725 @@
1
+ """
2
+
3
+ train_classifier_tf.py
4
+
5
+ Train an EfficientNet classifier.
6
+
7
+ Currently the implementation of multi-label multi-class classification is
8
+ non-functional.
9
+
10
+ During training, start tensorboard from within the classification/ directory:
11
+ tensorboard --logdir run --bind_all --samples_per_plugin scalars=0,images=0
12
+
13
+ """
14
+
15
+ #%% Imports and constants
16
+
17
+ from __future__ import annotations
18
+
19
+ import os
20
+ import json
21
+ import uuid
22
+ import argparse
23
+
24
+ import tqdm
25
+ import numpy as np
26
+ import sklearn.metrics
27
+ import tensorflow as tf
28
+
29
+ from collections import defaultdict
30
+ from collections.abc import Callable, Mapping, MutableMapping, Sequence
31
+ from datetime import datetime
32
+ from typing import Any, Optional
33
+ from tensorboard.plugins.hparams import api as hp
34
+
35
+ from megadetector.classification.train_utils import (
36
+ HeapItem, recall_from_confusion_matrix, add_to_heap, fig_to_img,
37
+ imgs_with_confidences, load_dataset_csv, prefix_all_keys)
38
+ from megadetector.visualization import plot_utils
39
+
40
# Sentinel passed to tf.data calls (num_parallel_calls / prefetch buffer_size)
# throughout this module.
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Supported architectures. Keys match the PyTorch EfficientNet model names.
# For each model:
#   cls:      name of the class looked up in tf.keras.applications
#   img_size: side length (pixels) of the square model input
#   dropout:  dropout rate applied before the final logits layer
EFFICIENTNET_MODELS: Mapping[str, Mapping[str, Any]] = {
    'efficientnet-b0': {
        'cls': 'EfficientNetB0', 'img_size': 224, 'dropout': 0.2},
    'efficientnet-b1': {
        'cls': 'EfficientNetB1', 'img_size': 240, 'dropout': 0.2},
    'efficientnet-b2': {
        'cls': 'EfficientNetB2', 'img_size': 260, 'dropout': 0.3},
    'efficientnet-b3': {
        'cls': 'EfficientNetB3', 'img_size': 300, 'dropout': 0.3},
    'efficientnet-b4': {
        'cls': 'EfficientNetB4', 'img_size': 380, 'dropout': 0.4},
    'efficientnet-b5': {
        'cls': 'EfficientNetB5', 'img_size': 456, 'dropout': 0.4},
    'efficientnet-b6': {
        'cls': 'EfficientNetB6', 'img_size': 528, 'dropout': 0.5},
    'efficientnet-b7': {
        'cls': 'EfficientNetB7', 'img_size': 600, 'dropout': 0.5},
}
53
+
54
+
55
+ #%% Example usage
56
+
57
+ """
58
+ python train_classifier_tf.py run_idfg /ssd/crops_sq \
59
+ -m "efficientnet-b0" --pretrained --finetune --label-weighted \
60
+ --epochs 50 --batch-size 512 --lr 1e-4 \
61
+ --seed 123 \
62
+ --logdir run_idfg
63
+ """
64
+
65
+
66
+ #%% Support functions
67
+
68
def create_dataset(
        img_files: Sequence[str],
        labels: Sequence[Any],
        sample_weights: Optional[Sequence[float]] = None,
        img_base_dir: str = '',
        transform: Optional[Callable[[tf.Tensor], Any]] = None,
        target_transform: Optional[Callable[[Any], Any]] = None,
        cache: bool | str = False
        ) -> tf.data.Dataset:
    """
    Create a tf.data.Dataset.

    The dataset returns elements (img, label, img_file, sample_weight) if
    sample_weights is not None, or (img, label, img_file) if
    sample_weights=None.
        img: tf.Tensor, shape [H, W, 3], type uint8 (before transform)
        label: tf.Tensor
        img_file: tf.Tensor, scalar, type str
        sample_weight: tf.Tensor, scalar, type float32

    Args:
        img_files: list of str, relative paths from img_base_dir
        labels: list of int if multilabel=False
        sample_weights: optional list of float
        img_base_dir: str, base directory for images; if empty, img_files are
            used as given (relative to the working directory)
        transform: optional transform to apply to a single uint8 JPEG image
        target_transform: optional transform to apply to a single label
        cache: bool or str, cache images in memory if True, cache images to
            a file on disk if a str

    Returns: tf.data.Dataset

    Raises:
        ValueError: if labels or sample_weights do not have the same length
            as img_files
    """

    # tf.data.Dataset.zip() stops at the shortest component, so mismatched
    # inputs would otherwise be truncated silently
    num_images = len(img_files)
    if len(labels) != num_images:
        raise ValueError(
            f'Got {len(labels)} labels for {num_images} image files')
    if sample_weights is not None and len(sample_weights) != num_images:
        raise ValueError(
            f'Got {len(sample_weights)} sample weights for '
            f'{num_images} image files')

    # os.path.join(x, '') appends a trailing separator only when x is
    # non-empty, so an empty img_base_dir no longer turns every relative path
    # into an absolute one ('/' + path)
    path_prefix = os.path.join(img_base_dir, '')

    # images dataset
    img_ds = tf.data.Dataset.from_tensor_slices(img_files)
    img_ds = img_ds.map(lambda p: tf.io.read_file(path_prefix + p),
                        num_parallel_calls=AUTOTUNE)

    # for smaller disk / memory usage, we cache the raw JPEG bytes instead
    # of the decoded Tensor
    if isinstance(cache, str):
        img_ds = img_ds.cache(cache)
    elif cache:
        img_ds = img_ds.cache()

    # convert JPEG bytes to a 3D uint8 Tensor
    # keras EfficientNet already includes normalization from [0, 255] to [0, 1],
    # so we don't need to do that here
    img_ds = img_ds.map(lambda img: tf.io.decode_jpeg(img, channels=3),
                        num_parallel_calls=AUTOTUNE)

    if transform:
        img_ds = img_ds.map(transform, num_parallel_calls=AUTOTUNE)

    # labels dataset
    labels_ds = tf.data.Dataset.from_tensor_slices(labels)
    if target_transform:
        labels_ds = labels_ds.map(target_transform, num_parallel_calls=AUTOTUNE)

    # img_files dataset
    img_files_ds = tf.data.Dataset.from_tensor_slices(img_files)

    components = [img_ds, labels_ds, img_files_ds]
    if sample_weights is not None:
        # weights dataset
        components.append(tf.data.Dataset.from_tensor_slices(sample_weights))
    return tf.data.Dataset.zip(tuple(components))
135
+
136
+
137
def create_dataloaders(
        dataset_csv_path: str,
        label_index_json_path: str,
        splits_json_path: str,
        cropped_images_dir: str,
        img_size: int,
        multilabel: bool,
        label_weighted: bool,
        weight_by_detection_conf: bool | str,
        batch_size: int,
        augment_train: bool,
        cache_splits: Sequence[str],
        train_cache_dir: str = '/mnt/tempds/'
        ) -> tuple[dict[str, tf.data.Dataset], list[str]]:
    """
    Builds one batched, prefetched tf.data.Dataset per split.

    Args:
        dataset_csv_path: str, path to CSV file with columns
            ['dataset', 'location', 'label'], where label is a comma-delimited
            list of labels
        label_index_json_path: str, path to JSON file, passed through to
            load_dataset_csv()
        splits_json_path: str, path to JSON file
        cropped_images_dir: str, base directory that image paths in the CSV
            are relative to
        img_size: int, images are resized (with padding) to
            [img_size, img_size]
        multilabel: bool, passed through to load_dataset_csv()
        label_weighted: bool, passed through to load_dataset_csv(); if True,
            per-example weights are included in each dataset element
        weight_by_detection_conf: bool or str, passed through to
            load_dataset_csv(); if truthy, per-example weights are included
            in each dataset element
        batch_size: int, batch size for every split
        augment_train: bool, whether to shuffle/augment the training set
        cache_splits: list of str, splits to cache
            training set is cached to a uniquely named file in train_cache_dir
            validation and test sets are cached in memory
        train_cache_dir: str, directory for the on-disk training-set cache,
            created if needed and only used if 'train' is in cache_splits

    Returns:
        datasets: dict, maps split to tf.data.Dataset
        label_names: list of str, label names in order of label id
    """

    df, label_names, split_to_locs = load_dataset_csv(
        dataset_csv_path, label_index_json_path, splits_json_path,
        multilabel=multilabel, label_weighted=label_weighted,
        weight_by_detection_conf=weight_by_detection_conf)

    # define the transforms
    #
    # - train:
    #   1) bicubic resize (aspect-preserving, padded) to img_size
    #   2) random horizontal flip
    #   3) random brightness / contrast / saturation jitter
    # - test:
    #   1) bicubic resize (aspect-preserving, padded) to img_size

    @tf.function
    def train_transform(img: tf.Tensor) -> tf.Tensor:
        """Returns: tf.Tensor, shape [img_size, img_size, C], type float32"""
        img = tf.image.resize_with_pad(img, img_size, img_size,
                                       method=tf.image.ResizeMethod.BICUBIC)
        img = tf.image.random_flip_left_right(img)
        # NOTE(review): img is float32 here and presumably still on the
        # [0, 255] scale, in which case max_delta=0.25 is a very small
        # brightness shift -- confirm this is intended
        img = tf.image.random_brightness(img, max_delta=0.25)
        img = tf.image.random_contrast(img, lower=0.75, upper=1.25)
        img = tf.image.random_saturation(img, lower=0.75, upper=1.25)
        return img

    @tf.function
    def test_transform(img: tf.Tensor) -> tf.Tensor:
        """Returns: tf.Tensor, shape [img_size, img_size, C], type float32"""
        img = tf.image.resize_with_pad(img, img_size, img_size,
                                       method=tf.image.ResizeMethod.BICUBIC)
        return img

    use_weights = bool(label_weighted or weight_by_detection_conf)

    dataloaders = {}
    for split, locs in split_to_locs.items():
        is_train = (split == 'train') and augment_train
        split_df = df[df['dataset_location'].isin(locs)]

        weights = None
        if use_weights:
            # weights sums to:
            # - if weight_by_detection_conf: (# images in split - conf delta)
            # - otherwise: (# images in split)
            weights = split_df['weights'].tolist()
            if not weight_by_detection_conf:
                assert np.isclose(sum(weights), len(split_df))

        # the training set is cached to disk; any other cached split is
        # cached in memory
        cache: bool | str = (split in cache_splits)
        if cache and split == 'train':
            os.makedirs(train_cache_dir, exist_ok=True)
            cache = os.path.join(train_cache_dir, str(uuid.uuid4()))

        ds = create_dataset(
            img_files=split_df['path'].tolist(),
            labels=split_df['label_index'].tolist(),
            sample_weights=weights,
            img_base_dir=cropped_images_dir,
            transform=train_transform if is_train else test_transform,
            target_transform=None,
            cache=cache)
        if is_train:
            ds = ds.shuffle(1000, reshuffle_each_iteration=True)
        ds = ds.batch(batch_size).prefetch(buffer_size=AUTOTUNE)
        dataloaders[split] = ds

    return dataloaders, label_names
234
+
235
+
236
def build_model(model_name: str, num_classes: int, img_size: int,
                pretrained: bool, finetune: bool) -> tf.keras.Model:
    """
    Creates a model with an EfficientNet base.

    Args:
        model_name: str, key into EFFICIENTNET_MODELS
        num_classes: int, number of units in the final 'logits' layer
        img_size: int, model input has shape [img_size, img_size, 3]
        pretrained: bool, whether to initialize the base model with
            'imagenet' weights
        finetune: bool, if True the base model is frozen, so only the new
            layers on top of it are trainable

    Returns: tf.keras.Model that outputs logits (no softmax). The EfficientNet
        base is also exposed as the attribute `base_model`, so that callers
        can unfreeze it later.

    Raises:
        KeyError: if model_name is not a key of EFFICIENTNET_MODELS
    """

    model_spec = EFFICIENTNET_MODELS[model_name]

    # getattr() rather than indexing the module's __dict__, so the lookup
    # goes through normal attribute resolution
    model_class = getattr(tf.keras.applications, model_spec['cls'])
    weights = 'imagenet' if pretrained else None
    inputs = tf.keras.layers.Input(shape=(img_size, img_size, 3))
    base_model = model_class(
        input_tensor=inputs, weights=weights, include_top=False, pooling='avg')

    if finetune:
        # freeze the base model's weights, including BatchNorm statistics
        # https://www.tensorflow.org/guide/keras/transfer_learning#fine-tuning
        base_model.trainable = False

    # rebuild output
    x = tf.keras.layers.Dropout(
        model_spec['dropout'], name='top_dropout')(base_model.output)
    outputs = tf.keras.layers.Dense(
        num_classes,
        kernel_initializer=tf.keras.initializers.VarianceScaling(
            scale=1. / 3., mode='fan_out', distribution='uniform'),
        name='logits')(x)
    model = tf.keras.Model(inputs, outputs, name='complete_model')
    model.base_model = base_model  # cache this so that we can turn off finetune
    return model
266
+
267
+
268
def log_images_with_confidence(
        heap_dict: Mapping[int, list[HeapItem]],
        label_names: Sequence[str],
        epoch: int,
        tag: str) -> None:
    """
    Writes, for every label in heap_dict, one image summary and one text
    summary (the corresponding image file names) to the default
    tf.summary writer.

    Args:
        heap_dict: dict, maps label_id to list of HeapItem, where each HeapItem
            data is a list [img, target, top3_conf, top3_preds, img_file],
            and img is a tf.Tensor of shape [H, W, 3]
        label_names: list of str, label names in order of label id
        epoch: int, used as the summary step
        tag: str, suffix of the summary name, after the label name
    """

    for label_id, heap in heap_dict.items():
        summary_name = f'{label_names[label_id]}/{tag}'

        # largest priority first
        ranked_items = sorted(heap, reverse=True)
        fig, img_files = imgs_with_confidences(
            [entry.data for entry in ranked_items], label_names)

        # tf.summary.image requires input of shape [N, H, W, C], so add a
        # leading batch axis of size 1
        batched_img = fig_to_img(fig)[np.newaxis, ...]
        tf.summary.image(
            summary_name, tf.convert_to_tensor(batched_img), step=epoch)

        files_text = '\n\n'.join(img_files)
        tf.summary.text(f'{summary_name}_files', files_text, step=epoch)
295
+
296
+
297
def track_extreme_examples(tp_heaps: dict[int, list[HeapItem]],
                           fp_heaps: dict[int, list[HeapItem]],
                           fn_heaps: dict[int, list[HeapItem]],
                           inputs: tf.Tensor,
                           labels: tf.Tensor,
                           img_files: tf.Tensor,
                           logits: tf.Tensor) -> None:
    """
    Updates the 5 most extreme true-positive (tp), false-positive (fp), and
    false-negative (fn) examples with examples from this batch.

    Each HeapItem's data attribute is a tuple with:
    - img: np.ndarray, shape [H, W, 3], type uint8
    - label: int
    - top3_conf: list of float, up to 3 highest softmax confidences, in
        descending order (fewer if the model has fewer than 3 classes)
    - top3_preds: list of int, class indices corresponding to top3_conf
    - img_file: str

    Args:
        *_heaps: dict, maps label_id (int) to heap of HeapItems
        inputs: tf.Tensor, shape [batch_size, H, W, 3], type float32
        labels: tf.Tensor, shape [batch_size]
        img_files: tf.Tensor, shape [batch_size], type tf.string
        logits: tf.Tensor, shape [batch_size, num_classes]
    """

    label_list = labels.numpy().tolist()
    imgs = inputs.numpy().astype(np.uint8)

    # string tensors come back from .numpy() as bytes; decode them explicitly
    # so that non-ASCII file names are handled
    file_list = [
        f.decode('utf-8') if isinstance(f, bytes) else str(f)
        for f in img_files.numpy().tolist()]

    batch_probs = tf.nn.softmax(logits, axis=1)

    # top_k() requires k <= num_classes, so clamp k for classifiers with
    # fewer than 3 classes
    k = min(3, int(batch_probs.shape[1]))
    topk_conf, topk_preds = tf.math.top_k(batch_probs, k=k, sorted=True)

    rows = zip(label_list, imgs, file_list, batch_probs.numpy(),
               topk_conf.numpy().tolist(), topk_preds.numpy().tolist())
    for label, img, img_file, confs, top3_conf, top3_preds in rows:
        label_conf = confs[label].item()

        data = (img, label, top3_conf, top3_preds, img_file)
        if top3_preds[0] == label:  # true positive
            # margin over the runner-up class; with a single class there is
            # no runner-up, so the margin is the confidence itself
            runner_up_conf = top3_conf[1] if len(top3_conf) > 1 else 0.0
            item = HeapItem(priority=label_conf - runner_up_conf, data=data)
            add_to_heap(tp_heaps[label], item, k=5)
        else:
            # false positive for top3_pred[0]
            # false negative for label
            item = HeapItem(priority=top3_conf[0] - label_conf, data=data)
            add_to_heap(fp_heaps[top3_preds[0]], item, k=5)
            add_to_heap(fn_heaps[label], item, k=5)
345
+
346
+
347
def run_epoch(model: tf.keras.Model,
              loader: tf.data.Dataset,
              weighted: bool,
              top: Sequence[int] = (1, 3),
              loss_fn: Optional[tf.keras.losses.Loss] = None,
              weight_decay: float = 0,
              finetune: bool = False,
              optimizer: Optional[tf.keras.optimizers.Optimizer] = None,
              return_extreme_images: bool = False
              ) -> tuple[
                  dict[str, float],
                  dict[str, dict[int, list[HeapItem]]],
                  np.ndarray
              ]:
    """
    Runs for 1 epoch. Trains the model if an optimizer is given, otherwise
    only evaluates it.

    Args:
        model: tf.keras.Model
        loader: tf.data.Dataset, yields (inputs, labels, img_files) or
            (inputs, labels, img_files, weights) batches
        weighted: bool, whether to use sample weights in calculating loss and
            accuracy; requires the loader to yield weights
        top: tuple of int, list of values of k for calculating top-K accuracy
        loss_fn: optional loss function, returns the per-example loss for a
            batch; required if optimizer is given
        weight_decay: float, L2-regularization constant
        finetune: bool, if true sets model's dropout and BN layers to eval mode
            and disables L2-regularization
        optimizer: optional optimizer; if given, the model is updated after
            every batch
        return_extreme_images: bool, whether to collect the most extreme
            tp/fp/fn examples; if False, the returned heaps are empty

    Returns:
        metrics: dict, metrics from epoch, contains keys:
            'loss': float, mean per-example loss over entire epoch,
                only included if loss_fn is not None
            'acc_top{k}': float, accuracy@k over the entire epoch
        heaps: dict, keys are ['tp', 'fp', 'fn'], values are heap_dicts,
            each heap_dict maps label_id (int) to a heap of <= 5 HeapItems with
            data attribute (img, target, top3_conf, top3_preds, img_file)
            - 'tp': priority is the difference between target confidence and
                2nd highest confidence
            - 'fp': priority is the difference between highest confidence and
                target confidence
            - 'fn': same as 'fp'
        confusion_matrix: np.ndarray, shape [num_classes, num_classes],
            C[i, j] = # of samples with true label i, predicted as label j
    """
    # if evaluating or finetuning, set dropout & BN layers to eval mode
    is_train = optimizer is not None
    if is_train:
        assert loss_fn is not None, 'loss_fn is required when training'
    train_dropout_and_bn = is_train and not finetune

    # variables that receive the L2 penalty; left empty when evaluating or
    # finetuning
    reg_vars = []
    if train_dropout_and_bn:
        reg_vars = [
            v for v in model.trainable_variables if 'kernel' in v.name]

    if loss_fn is not None:
        losses = tf.keras.metrics.Mean()
    accuracies_topk = {
        k: tf.keras.metrics.SparseTopKCategoricalAccuracy(k) for k in top
    }

    # for each label, track 5 most-confident and least-confident examples
    tp_heaps: dict[int, list[HeapItem]] = defaultdict(list)
    fp_heaps: dict[int, list[HeapItem]] = defaultdict(list)
    fn_heaps: dict[int, list[HeapItem]] = defaultdict(list)

    all_labels = []
    all_preds = []
    num_classes = None

    tqdm_loader = tqdm.tqdm(loader)
    for batch in tqdm_loader:
        if weighted:
            inputs, labels, img_files, weights = batch
        else:
            # even if batch contains sample weights, don't use them
            inputs, labels, img_files = batch[0:3]
            weights = None

        all_labels.append(labels.numpy())
        desc = []
        with tf.GradientTape(watch_accessed_variables=is_train) as tape:
            outputs = model(inputs, training=train_dropout_and_bn)
            if loss_fn is not None:
                loss = loss_fn(labels, outputs)
                if weights is not None:
                    loss *= weights
                # we do not track L2-regularization loss in the loss metric
                # NOTE(review): loss has already been multiplied by weights
                # above and is weighted again via sample_weight here --
                # confirm that this double weighting is intended
                losses.update_state(loss, sample_weight=weights)
                desc.append(f'Loss {losses.result().numpy():.4f}')

            if is_train:
                loss = tf.math.reduce_mean(loss)
                if reg_vars:
                    # tf.add_n needs a non-empty list of tensors
                    loss += weight_decay * tf.add_n(
                        [tf.nn.l2_loss(v) for v in reg_vars])

        num_classes = int(outputs.shape[-1])
        all_preds.append(tf.math.argmax(outputs, axis=1).numpy())

        if is_train:
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        for k, acc in accuracies_topk.items():
            acc.update_state(labels, outputs, sample_weight=weights)
            desc.append(f'Acc@{k} {acc.result().numpy() * 100:.3f}')
        tqdm_loader.set_description(' '.join(desc))

        if return_extreme_images:
            track_extreme_examples(tp_heaps, fp_heaps, fn_heaps, inputs,
                                   labels, img_files, outputs)

    # pass the full list of class ids so that the matrix is always
    # [num_classes, num_classes], even if some classes never occur in this
    # split (neither as a label nor as a prediction)
    cm_labels = None if num_classes is None else np.arange(num_classes)
    confusion_matrix = sklearn.metrics.confusion_matrix(
        y_true=np.concatenate(all_labels), y_pred=np.concatenate(all_preds),
        labels=cm_labels)

    metrics = {}
    if loss_fn is not None:
        metrics['loss'] = losses.result().numpy().item()
    for k, acc in accuracies_topk.items():
        metrics[f'acc_top{k}'] = acc.result().numpy().item() * 100
    heaps = {'tp': tp_heaps, 'fp': fp_heaps, 'fn': fn_heaps}
    return metrics, heaps, confusion_matrix
470
+
471
+
472
def log_run(split: str, epoch: int, writer: tf.summary.SummaryWriter,
            label_names: Sequence[str], metrics: MutableMapping[str, float],
            heaps: Mapping[str, Mapping[int, list[HeapItem]]], cm: np.ndarray
            ) -> None:
    """
    Writes everything produced by one epoch over one split to Tensorboard:
    scalar metrics (plus per-label recall derived from the confusion matrix),
    the confusion matrix rendered as an image, and the tp/fp/fn example
    images.

    Note that the per-label recall values are added to `metrics` in place.

    Args:
        split: str, name of the split, used in the summary names
        epoch: int, used as the summary step
        writer: tf.summary.SummaryWriter, flushed after everything is logged
        label_names: list of str, label names in order of label id
        metrics: dict, keys already prefixed with {split}/
        heaps: dict, maps heap type to heap_dict, each heap_dict maps label_id
            to a list of HeapItem
        cm: np.ndarray, confusion matrix
    """

    # extend the caller's metrics with one recall entry per label
    recall_by_label = recall_from_confusion_matrix(cm, label_names)
    recall_metrics = prefix_all_keys(recall_by_label, f'{split}/label_recall/')
    metrics.update(recall_metrics)

    # scalars
    for name in metrics:
        tf.summary.scalar(name, metrics[name], step=epoch)

    # confusion matrix; tf.summary.image needs a leading batch axis
    figure = plot_utils.plot_confusion_matrix(
        cm, classes=label_names, normalize=True)
    figure_batch = fig_to_img(figure)[np.newaxis, ...]
    tf.summary.image(f'confusion_matrix/{split}',
                     tf.convert_to_tensor(figure_batch), step=epoch)

    # tp/fp/fn example images
    for heap_type, heap_dict in heaps.items():
        heap_tag = f'{split}/{heap_type}'
        log_images_with_confidence(
            heap_dict, label_names, epoch=epoch, tag=heap_tag)

    writer.flush()
502
+
503
+
504
+ #%% Main function
505
+
506
+ def main(dataset_dir: str,
507
+ cropped_images_dir: str,
508
+ multilabel: bool,
509
+ model_name: str,
510
+ pretrained: bool,
511
+ finetune: int,
512
+ label_weighted: bool,
513
+ weight_by_detection_conf: bool | str,
514
+ epochs: int,
515
+ batch_size: int,
516
+ lr: float,
517
+ weight_decay: float,
518
+ seed: Optional[int] = None,
519
+ logdir: str = '',
520
+ cache_splits: Sequence[str] = ()) -> None:
521
+
522
+ # input validation
523
+ assert os.path.exists(dataset_dir)
524
+ assert os.path.exists(cropped_images_dir)
525
+ if isinstance(weight_by_detection_conf, str):
526
+ assert os.path.exists(weight_by_detection_conf)
527
+
528
+ # set seed
529
+ seed = np.random.randint(10_000) if seed is None else seed
530
+ np.random.seed(seed)
531
+ tf.random.set_seed(seed)
532
+
533
+ # create logdir and save params
534
+ params = dict(locals()) # make a copy
535
+ timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') # '20200722_110816'
536
+ logdir = os.path.join(logdir, timestamp)
537
+ os.makedirs(logdir, exist_ok=True)
538
+ print('Created logdir:', logdir)
539
+ with open(os.path.join(logdir, 'params.json'), 'w') as f:
540
+ json.dump(params, f, indent=1)
541
+
542
+ gpus = tf.config.experimental.list_physical_devices('GPU')
543
+ for gpu in gpus:
544
+ tf.config.experimental.set_memory_growth(gpu, True)
545
+
546
+ img_size = EFFICIENTNET_MODELS[model_name]['img_size']
547
+
548
+ # create dataloaders and log the index_to_label mapping
549
+ loaders, label_names = create_dataloaders(
550
+ dataset_csv_path=os.path.join(dataset_dir, 'classification_ds.csv'),
551
+ label_index_json_path=os.path.join(dataset_dir, 'label_index.json'),
552
+ splits_json_path=os.path.join(dataset_dir, 'splits.json'),
553
+ cropped_images_dir=cropped_images_dir,
554
+ img_size=img_size,
555
+ multilabel=multilabel,
556
+ label_weighted=label_weighted,
557
+ weight_by_detection_conf=weight_by_detection_conf,
558
+ batch_size=batch_size,
559
+ augment_train=True,
560
+ cache_splits=cache_splits)
561
+
562
+ writer = tf.summary.create_file_writer(logdir)
563
+ writer.set_as_default()
564
+
565
+ model = build_model(
566
+ model_name, num_classes=len(label_names), img_size=img_size,
567
+ pretrained=pretrained, finetune=finetune > 0)
568
+
569
+ # define loss function and optimizer
570
+ loss_fn: tf.keras.losses.Loss
571
+ if multilabel:
572
+ loss_fn = tf.keras.losses.BinaryCrossentropy(
573
+ from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
574
+ else:
575
+ loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
576
+ from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
577
+
578
+ # using EfficientNet training defaults
579
+ # - batch norm momentum: 0.99
580
+ # - optimizer: RMSProp, decay 0.9 and momentum 0.9
581
+ # - epochs: 350
582
+ # - learning rate: 0.256, decays by 0.97 every 2.4 epochs
583
+ # - weight decay: 1e-5
584
+ lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
585
+ lr, decay_steps=1, decay_rate=0.97, staircase=True)
586
+ optimizer = tf.keras.optimizers.RMSprop(
587
+ learning_rate=lr, rho=0.9, momentum=0.9)
588
+
589
+ best_epoch_metrics: dict[str, float] = {}
590
+ for epoch in range(epochs):
591
+ print(f'Epoch: {epoch}')
592
+ optimizer.learning_rate = lr_schedule(epoch)
593
+ tf.summary.scalar('lr', optimizer.learning_rate, epoch)
594
+
595
+ if epoch > 0 and finetune == epoch:
596
+ print('Turning off fine-tune!')
597
+ model.base_model.trainable = True
598
+
599
+ print('- train:')
600
+
601
+ train_metrics, train_heaps, train_cm = run_epoch(
602
+ model, loader=loaders['train'], weighted=label_weighted,
603
+ loss_fn=loss_fn, weight_decay=weight_decay, optimizer=optimizer,
604
+ finetune=finetune > epoch, return_extreme_images=True)
605
+ train_metrics = prefix_all_keys(train_metrics, prefix='train/')
606
+ log_run('train', epoch, writer, label_names,
607
+ metrics=train_metrics, heaps=train_heaps, cm=train_cm)
608
+
609
+ print('- val:')
610
+ val_metrics, val_heaps, val_cm = run_epoch(
611
+ model, loader=loaders['val'], weighted=label_weighted,
612
+ loss_fn=loss_fn, return_extreme_images=True)
613
+ val_metrics = prefix_all_keys(val_metrics, prefix='val/')
614
+ log_run('val', epoch, writer, label_names,
615
+ metrics=val_metrics, heaps=val_heaps, cm=val_cm)
616
+
617
+ if val_metrics['val/acc_top1'] > best_epoch_metrics.get('val/acc_top1', 0): # pylint: disable=line-too-long
618
+ filename = os.path.join(logdir, f'ckpt_{epoch}.h5')
619
+ print(f'New best model! Saving checkpoint to {filename}')
620
+ model.save(filename)
621
+ best_epoch_metrics.update(train_metrics)
622
+ best_epoch_metrics.update(val_metrics)
623
+ best_epoch_metrics['epoch'] = epoch
624
+
625
+ print('- test:')
626
+ test_metrics, test_heaps, test_cm = run_epoch(
627
+ model, loader=loaders['test'], weighted=label_weighted,
628
+ loss_fn=loss_fn, return_extreme_images=True)
629
+ test_metrics = prefix_all_keys(test_metrics, prefix='test/')
630
+ log_run('test', epoch, writer, label_names,
631
+ metrics=test_metrics, heaps=test_heaps, cm=test_cm)
632
+
633
+ # stop training after 8 epochs without improvement
634
+ if epoch >= best_epoch_metrics['epoch'] + 8:
635
+ break
636
+
637
+ hparams_dict = {
638
+ 'model_name': model_name,
639
+ 'multilabel': multilabel,
640
+ 'finetune': finetune,
641
+ 'batch_size': batch_size,
642
+ 'epochs': epochs
643
+ }
644
+ hp.hparams(hparams_dict)
645
+ writer.close()
646
+
647
+
648
+ #%% Command-line driver
649
+
650
def _parse_args() -> argparse.Namespace:
    """Build the classifier-training CLI parser and parse the command line.

    Returns:
        argparse.Namespace with the two positional paths (dataset_dir,
        cropped_images_dir) and every optional training setting. ``lr`` is
        None when --lr is not supplied.
    """
    parser = argparse.ArgumentParser(
        description='Trains classifier.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    add_arg = parser.add_argument
    model_choices = list(EFFICIENTNET_MODELS.keys())

    # --- positional inputs ---
    add_arg(
        'dataset_dir',
        help=('path to directory containing: '
              '1) classification dataset CSV, '
              '2) label index JSON, 3) splits JSON'))
    add_arg(
        'cropped_images_dir',
        help='path to local directory where image crops are saved')

    # --- model selection and training mode ---
    add_arg(
        '--multilabel', action='store_true',
        help='for multi-label, multi-class classification')
    add_arg(
        '-m', '--model-name', choices=model_choices,
        default='efficientnet-b0',
        help='which EfficientNet model')
    add_arg(
        '--pretrained', action='store_true',
        help='start with pretrained model')
    add_arg(
        '--finetune', default=0, type=int,
        help=('only fine tune the final fully-connected layer '
              'for the first <finetune> epochs'))

    # --- sample weighting ---
    add_arg(
        '--label-weighted', action='store_true',
        help='weight training samples to balance labels')
    # flag absent -> False, bare flag -> True, flag with a value -> that value
    add_arg(
        '--weight-by-detection-conf', default=False, const=True, nargs='?',
        help=('weight training examples by detection confidence. Optionally '
              'takes a .npz file for isotonic calibration.'))

    # --- optimization hyperparameters ---
    add_arg(
        '--epochs', default=0, type=int,
        help='number of epochs for training, 0 for eval-only')
    add_arg(
        '--batch-size', default=256, type=int,
        help='batch size for both training and eval')
    add_arg(
        '--lr', default=None, type=float,
        help='initial learning rate, defaults to (0.016 * batch_size / 256)')
    add_arg(
        '--weight-decay', default=1e-5, type=float,
        help='weight decay')
    add_arg(
        '--seed', type=int,
        help='random seed')

    # --- logging and caching ---
    add_arg(
        '--logdir', default='.',
        help='directory where TensorBoard logs and a params file are saved')
    add_arg(
        '--cache', default=(), nargs='*', choices=['train', 'val', 'test'],
        help='which splits of the dataset to cache')

    return parser.parse_args()
705
+
706
+
707
if __name__ == '__main__':
    args = _parse_args()

    # No --lr given: fall back to a rate scaled linearly with the batch size.
    if args.lr is None:
        args.lr = 0.016 * args.batch_size / 256  # based on TF models repo

    # Parsed options that are handed to main() under the same name.
    passthrough = (
        'dataset_dir', 'cropped_images_dir', 'multilabel', 'model_name',
        'pretrained', 'finetune', 'label_weighted',
        'weight_by_detection_conf', 'epochs', 'batch_size', 'lr',
        'weight_decay', 'seed', 'logdir')
    main_kwargs = {name: getattr(args, name) for name in passthrough}

    # --cache is the only option whose main() parameter name differs.
    main(cache_splits=args.cache, **main_kwargs)