awc_helpers-0.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- awc_helpers/__init__.py +20 -0
- awc_helpers/awc_inference.py +452 -0
- awc_helpers/math_utils.py +82 -0
- awc_helpers-0.1.1.dist-info/LICENSE +39 -0
- awc_helpers-0.1.1.dist-info/METADATA +88 -0
- awc_helpers-0.1.1.dist-info/RECORD +8 -0
- awc_helpers-0.1.1.dist-info/WHEEL +5 -0
- awc_helpers-0.1.1.dist-info/top_level.txt +1 -0
awc_helpers/__init__.py
ADDED
@@ -0,0 +1,20 @@
"""AWC Helpers - Wildlife detection and classification inference tools."""

from .awc_inference import (
    DetectAndClassify,
    SpeciesClasInference,
    format_md_detections,
    load_classification_model,
)
from .math_utils import crop_image, pil_to_tensor

__version__ = "0.1.0"

__all__ = [
    "DetectAndClassify",
    "SpeciesClasInference",
    "format_md_detections",
    "load_classification_model",
    "crop_image",
    "pil_to_tensor",
]
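Everything listed in `__all__` is re-exported at the package root, so downstream code can import the full public surface without reaching into the submodules. A minimal sketch of those imports:

```python
# Illustrative only: the flat import surface exposed by awc_helpers/__init__.py.
from awc_helpers import (
    DetectAndClassify,          # end-to-end detect + classify pipeline
    SpeciesClasInference,       # classification-only engine
    format_md_detections,       # MegaDetector output formatter
    load_classification_model,  # timm model loader
    crop_image,                 # bbox crop helper
    pil_to_tensor,              # PIL -> normalized tensor
)
```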
awc_helpers/awc_inference.py
ADDED
@@ -0,0 +1,452 @@

"""
Author: Quan Tran (Australian Wildlife Conservancy)
Wildlife species detection and classification inference module.

This module provides classes and functions for running inference pipelines
that combine MegaDetector-based animal detection with species classification
using fine-tuned image classification models (from timm library).

Classes:
    SpeciesClasInference: Run species classification on pre-detected animal crops.
    DetectAndClassify: End-to-end pipeline combining detection and classification.

Functions:
    format_md_detections: Format MegaDetector outputs for classification input.
    load_classification_model: Load a timm-based classification model.
"""

from zoneinfo import ZoneInfo
import datetime
import timm
import torch
import numpy as np
from pathlib import Path
from megadetector.detection import run_detector
from typing import List, Tuple, Union
from PIL import Image
from .math_utils import crop_image, pil_to_tensor
import logging

logger = logging.getLogger(__name__)



def format_md_detections(md_result: dict,
                         filter_category: str = 'animal',
                         for_clas: bool = True) -> List:
    """
    Format MegaDetector outputs for classification input or other uses.

    Args:
        md_result: Dictionary containing MegaDetector detection results with keys
            'file', 'detections', and optionally 'PIL' for in-memory images.
        filter_category: Category to filter detections by (e.g., 'animal', 'person').
            If None or empty, all detections are included.
        for_clas: If True, format output for classification pipeline input.
            Otherwise, format output with full detection metadata.

    Returns:
        List of formatted detection results. Format depends on `for_clas`:
        - If for_clas=True: List of (img_path, bbox_confidence, bbox) or (PIL, identifier, bbox_confidence, bbox) tuples
        - If for_clas=False: List of [file, category, bbox_confidence, bbox] lists
    """
    md_animal_id = next((k for k, v in run_detector.DEFAULT_DETECTOR_LABEL_MAP.items() if v == filter_category), None)
    results=[]
    img_file = md_result['file']
    if 'detections' in md_result and md_result['detections'] is not None and len(md_result['detections'])>0:
        for i,_d in enumerate(md_result['detections']):
            if not filter_category or _d['category'] == md_animal_id:
                if for_clas:
                    if 'PIL' in md_result:
                        results.append((md_result['PIL'], img_file, _d['conf'], tuple(_d['bbox'])))
                    else:
                        results.append((Path(img_file).as_posix(), _d['conf'], tuple(_d['bbox'])))
                else:
                    results.append([Path(img_file).as_posix(),_d['category'],_d['conf'],tuple(_d['bbox'])])
    return results

def load_classification_model(
        finetuned_model: str = None,
        classification_model: str = 'tf_efficientnet_b5.ns_jft_in1k',
        label_info: Union[List[str], int] = None
):
    """
    Load a timm-based image classification model.

    Creates a classification model using the timm library, optionally loading
    fine-tuned weights from a checkpoint file.

    Args:
        finetuned_model: Path to fine-tuned model weights (.pth file).
            If None, loads pretrained ImageNet weights.
        classification_model: Name of the timm model architecture.
            Hyphens are automatically converted to underscores.
        label_info: Either a list of class label names, or an integer
            specifying the number of output classes.

    Returns:
        torch.nn.Module: The loaded classification model.

    Raises:
        FileNotFoundError: If finetuned_model path does not exist.
    """
    # Convert model name format for timm
    timm_model_name = classification_model.replace('-', '_')
    num_classes = label_info if isinstance(label_info, int) else len(label_info)
    if finetuned_model is not None:
        # Create model with timm (without pretrained weights)
        model = timm.create_model(timm_model_name, pretrained=False, num_classes=num_classes)
        # Load fine-tuned weights
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        state_dict = torch.load(finetuned_model, map_location=device)
        ret = model.load_state_dict(state_dict, strict=False)
        if len(ret.missing_keys):
            logger.warning(f'Missing weights: {ret.missing_keys}')
        if len(ret.unexpected_keys):
            logger.warning(f'Unexpected weights: {ret.unexpected_keys}')
        logger.info(f'Loaded finetuned timm classification model: {Path(finetuned_model).name} with {num_classes} classes')
    else:
        model = timm.create_model(timm_model_name, pretrained=True, num_classes=num_classes)
        logger.info(f'Loaded pretrained timm classification model: {timm_model_name} with {num_classes} classes')
    return model

class SpeciesClasInference:
    """
    Species classification inference engine for wildlife images.

    This class handles loading a classification model and running inference
    on cropped animal detections. It supports batch processing, GPU acceleration,
    and mixed-precision inference.

    Attributes:
        device: PyTorch device (cuda or cpu) for model inference.
        model: The loaded classification model.
        label_names: List of class label names.
        clas_threshold: Minimum confidence threshold for predictions.
        pred_topn: Number of top predictions to return per image.
        prob_round: Decimal places to round probabilities.
        use_fp16: Whether to use FP16 mixed precision inference.
        resize_size: Target size for image resizing before inference.
        skip_errors: Whether to skip images that fail to process.

    Example:
        >>> classifier = SpeciesClasInference(
        ...     classifier_path='model.pth',
        ...     classifier_base='tf_efficientnet_b5.ns_jft_in1k',
        ...     label_names=['cat', 'dog', 'bird']
        ... )
        >>> results = classifier.predict_batch([(image_path, bbox)])
    """

    def __init__(self,
                 classifier_path: str,
                 classifier_base: str,
                 label_names: List[str] = None,
                 pred_topn: int = 1,
                 prob_round: int = 4,
                 clas_threshold: float = 0.5,
                 resize_size: int = 300,
                 force_cpu: bool = False,
                 use_fp16: bool = False,
                 skip_errors: bool = True):
        """
        Initialize the species classification inference engine.
        """
        if torch.cuda.is_available() and not force_cpu:
            self.device = torch.device('cuda')
            logger.info(f"\tGPU Device: {torch.cuda.get_device_name()}")
        else:
            self.device = torch.device('cpu')

        self.model = load_classification_model(finetuned_model=classifier_path,
                                                classification_model=classifier_base,
                                                label_info=label_names)
        self.label_names = label_names
        self.model = self.model.to(self.device)
        self.model.eval()

        self.clas_threshold = clas_threshold
        self.pred_topn=pred_topn
        self.prob_round=prob_round
        self.use_fp16=use_fp16 and self.device.type=='cuda'
        self.resize_size=resize_size
        self.skip_errors=skip_errors

    def _prepare_crop(
        self,
        source: Union[str, Image.Image],
        bbox_norm: Tuple[float, float, float, float]
    ) -> Image.Image:
        """
        Load (if path) and crop image to bounding box.

        Args:
            source: Image path string OR PIL Image
            bbox_norm: Normalized bounding box (x_min, y_min, width, height)

        Returns:
            Cropped RGB PIL Image
        """
        if isinstance(source, str):
            with Image.open(source) as img:
                img.load()
                img = img.convert('RGB') if img.mode != 'RGB' else img
                return crop_image(img, bbox_norm, square_crop=True)
        else:
            img = source.convert('RGB') if source.mode != 'RGB' else source
            return crop_image(img, bbox_norm, square_crop=True)

    def _predict(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """
        Run classification model on input tensor.

        Args:
            input_tensor: Tensor of shape (B, 3, H, W)

        Returns:
            Tuple of
            - Probabilities tensor of shape (B, num_classes)
            - Indices tensor of shape (B, num_classes)
        """
        with torch.no_grad():
            if self.use_fp16:
                with torch.amp.autocast('cuda'):
                    logits = self.model(input_tensor)
            else:
                logits = self.model(input_tensor)

            # Softmax in fp32 for numerical stability
            probs = torch.nn.functional.softmax(logits.float(), dim=1)

            top_probs, top_indices = torch.topk(probs, k=self.pred_topn, dim=1)
        return (top_probs.cpu().numpy().round(self.prob_round),
                top_indices.cpu().numpy())

    def _format_output(self,
                       identifier: Union[str, None],
                       bbox_conf: float,
                       bbox_norm: Tuple[float, float, float, float],
                       top_probs: np.ndarray,
                       top_indices: np.ndarray):
        """
        Build output result tuple from predictions.

        Args:
            identifier: Image path or custom ID
            bbox_conf: Confidence score of the bounding box
            bbox_norm: Bounding box used
            top_probs: Top-k probabilities (1D array)
            top_indices: Top-k class indices (1D array)

        Returns:
            Tuple: (identifier, bbox_conf, bbox, label1, prob1, label2, prob2, ...)
        """
        result = [identifier, bbox_conf, bbox_norm]
        for k in range(len(top_indices)):
            label = self.label_names[top_indices[k]]
            prob = round(float(top_probs[k]), self.prob_round)
            if prob >= self.clas_threshold:
                result.extend([label, prob])

        return tuple(result)

    def predict_batch(
        self,
        inputs: List[Union[Tuple[str, float, Tuple[float, float, float, float]], Tuple[Image.Image, str, float, Tuple[float, float, float, float]]]],
        batch_size: int = 1,
    ) -> List[Tuple]:
        """
        Run inference on a batch of inputs.

        Args:
            inputs: List of (img_path, bbox_confidence, bbox) tuples, or (PIL Image, id, bbox_confidence, bbox) tuples for streaming
            pred_topn: Number of top predictions to return
            prob_round: Decimal places to round probabilities
            batch_size: Number of images to process at once

        Returns:
            List of result tuples
        """
        results = []

        for batch_start in range(0, len(inputs), batch_size):
            batch_inputs = inputs[batch_start:batch_start + batch_size]

            # Preprocess batch
            batch_tensors = []
            batch_metadata = []  # (identifier, bbox_conf, bbox)

            for *sources, bbox in batch_inputs:
                try:
                    identifier = sources[0] if len(sources)==2 else sources[1]
                    bbox_conf = round(sources[-1], self.prob_round)
                    cropped = self._prepare_crop(sources[0], bbox)
                    tensor = pil_to_tensor(cropped,resize_size=self.resize_size).to(self.device)
                    batch_tensors.append(tensor)
                    batch_metadata.append((identifier, bbox_conf, bbox))
                except Exception as e:
                    if self.skip_errors:
                        logger.warning(f"Failed to process {identifier}: {e}")
                        continue
                    raise

            if not batch_tensors:
                continue

            # Stack and run inference
            batch_tensor = torch.cat(batch_tensors, dim=0)
            top_probs, top_indices = self._predict(batch_tensor)

            for i, (identifier, bbox_conf, bbox) in enumerate(batch_metadata):
                result = self._format_output(
                    identifier, bbox_conf, bbox,
                    top_probs[i], top_indices[i],
                )
                results.append(result)

        return results


class DetectAndClassify:
    """
    End-to-end wildlife detection and classification pipeline.

    Combines MegaDetector for animal detection with a species classifier
    to provide a complete inference pipeline from raw images to species
    predictions.

    Attributes:
        md_detector: MegaDetector model instance for animal detection.
        clas_inference: SpeciesClasInference instance for classification.
        detection_threshold: Minimum confidence for detection filtering.

    Example:
        >>> pipeline = DetectAndClassify(
        ...     detector_path='md_v5a.0.0.pt',
        ...     classifier_path='species_model.pth',
        ...     label_names=['kangaroo', 'wallaby', 'wombat']
        ... )
        >>> results = pipeline.predict('wildlife_image.jpg')
    """

    def __init__(self,
                 detector_path: str,
                 classifier_path: str,
                 label_names: List[str],
                 classifier_base: str = 'tf_efficientnet_b5.ns_jft_in1k',
                 detection_threshold: float = 0.1,
                 clas_threshold: float = 0.5,
                 pred_topn: int = 1,
                 resize_size: int = 300,
                 force_cpu: bool = False,
                 skip_clas_errors: bool = True):
        """
        Initialize the detection and classification pipeline.

        Args:
            detector_path: Path to the MegaDetector model weights.
            classifier_path: Path to the species classifier weights.
            label_names: List of species class names.
            classifier_base: Name of the base timm model architecture.
            detection_threshold: Minimum confidence for animal detections.
            clas_threshold: Minimum confidence for classification predictions.
            pred_topn: Number of top classification predictions to return.
            resize_size: Target image size for classification model input.
            force_cpu: If True, use CPU even if CUDA is available.
            skip_clas_errors: If True, skip classification errors instead of raising.
        """
        self.md_detector = run_detector.load_detector(str(detector_path),
                                                      force_cpu=force_cpu)
        self.clas_inference = SpeciesClasInference(classifier_path=classifier_path,
                                                   classifier_base=classifier_base,
                                                   clas_threshold=clas_threshold,
                                                   label_names=label_names,
                                                   pred_topn=pred_topn,
                                                   resize_size=resize_size,
                                                   force_cpu=force_cpu,
                                                   skip_errors=skip_clas_errors)
        self.detection_threshold = detection_threshold

    def _validate_input(
        self,
        inp: Union[str, Image.Image, List[Union[str, Image.Image]]],
        identifier: Union[str, List[str], None]
    ) -> Tuple[List, List]:
        """
        Validate and normalize input images and identifiers.

        Args:
            inp: Single image or list of images (paths or PIL Images).
            identifier: Optional identifier(s) for the images. If None,
                uses file paths for string inputs or timestamps for PIL images.

        Returns:
            Tuple of (normalized_inputs, normalized_identifiers)

        Raises:
            AssertionError: If identifier list length doesn't match input list length.
        """
        if not isinstance(inp, (list, tuple)):
            inp = [inp]
        elif len(inp)==0:
            return [],[]
        if identifier is None:
            if isinstance(inp[0], str):
                identifier = inp
            else:
                # identifier based on date+time in human readable format, utc time
                now = datetime.datetime.now(ZoneInfo("Australia/Perth"))
                now_str = now.strftime("%Y%m%d_%H%M%S") + f"_{now.microsecond // 1000:03d}"
                identifier = [now_str] if len(inp) == 1 else [f'{now_str}_{i+1}' for i in range(len(inp))]

        elif not isinstance(identifier, (list, tuple)):
            identifier = [identifier]

        assert len(identifier) == len(inp), "Length of identifier list (containing e.g. image names) must match length of input list."
        return inp, identifier

    def predict(
        self,
        inp: Union[str, Image.Image, List[Union[str, Image.Image]]],
        identifier: Union[str, List[str], None] = None,
        clas_bs: int = 4
    ) -> List[Tuple]:
        """
        Run detection and classification on input images.

        Processes images through the MegaDetector to find animals, then
        classifies each detected animal using the species classifier.

        Args:
            inp: Single image or list of images. Can be file paths (str)
                or PIL Image objects.
            identifier: Optional identifier(s) for tracking results back to
                source images. If None, uses file paths or timestamps.
            clas_bs: Batch size for classification inference.

        Returns:
            List of result tuples, one per detected animal. Each tuple contains:
            (identifier, bbox, label1, prob1, label2, prob2, ...) where the
            number of label/prob pairs depends on pred_topn and clas_threshold.
        """
        inp, identifier = self._validate_input(inp, identifier)
        if len(inp) == 0:
            return []

        md_results=[]
        for item,id in zip(inp, identifier):
            img = item
            if isinstance(item,str):
                img = Image.open(item)
            try:
                md_result = self.md_detector.generate_detections_one_image(img,id,
                                                                           detection_threshold=self.detection_threshold)
                if not isinstance(item,str):
                    md_result['PIL'] = img
                md_results.extend(format_md_detections(md_result))
            finally:
                if isinstance(item,str):
                    img.close()

        return self.clas_inference.predict_batch(md_results, batch_size=clas_bs)
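To make the hand-off between the two stages concrete, the sketch below fabricates a MegaDetector-style result dict in the shape `format_md_detections` reads, then feeds the formatted crops to `SpeciesClasInference.predict_batch`. The image path, weight file, label names, and detection values are placeholders, not assets shipped with this wheel:

```python
# Sketch only: paths, labels, and detection values below are placeholders.
from awc_helpers import SpeciesClasInference, format_md_detections

md_result = {
    'file': 'camera_trap/IMG_0001.JPG',        # hypothetical image path
    'detections': [
        {'category': '1',                      # '1' maps to 'animal' in MegaDetector's label map
         'conf': 0.92,
         'bbox': [0.10, 0.20, 0.30, 0.40]},    # normalized [x_min, y_min, width, height]
    ],
}

# -> [('camera_trap/IMG_0001.JPG', 0.92, (0.1, 0.2, 0.3, 0.4))]
crops = format_md_detections(md_result, filter_category='animal', for_clas=True)

classifier = SpeciesClasInference(
    classifier_path='species_model.pth',       # hypothetical fine-tuned weights
    classifier_base='tf_efficientnet_b5.ns_jft_in1k',
    label_names=['kangaroo', 'wallaby', 'wombat'],
    pred_topn=1,
)

# Each result is (identifier, bbox_conf, bbox, label, prob, ...) per _format_output.
results = classifier.predict_batch(crops, batch_size=4)
```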
awc_helpers/math_utils.py
ADDED
@@ -0,0 +1,82 @@
from PIL import Image, ImageOps
from typing import Sequence
import torch
import numpy as np

def crop_image(img: Image.Image, bbox_norm: Sequence[float], square_crop: bool = True) -> Image.Image:
    """
    Crop image based on normalized bounding box (MegaDetector format).

    Args:
        img: PIL Image to crop
        bbox_norm: [x_min, y_min, width, height] all normalized 0-1
        square_crop: Whether to make the crop square with padding

    Returns:
        Cropped PIL Image
    """
    img_w, img_h = img.size
    xmin = int(bbox_norm[0] * img_w)
    ymin = int(bbox_norm[1] * img_h)
    box_w = int(bbox_norm[2] * img_w)
    box_h = int(bbox_norm[3] * img_h)

    if square_crop:
        box_size = max(box_w, box_h)
        xmin = max(0, min(xmin - int((box_size - box_w) / 2), img_w - box_w))
        ymin = max(0, min(ymin - int((box_size - box_h) / 2), img_h - box_h))
        box_w = min(img_w, box_size)
        box_h = min(img_h, box_size)

    if box_w == 0 or box_h == 0:
        raise ValueError(f'Invalid crop dimensions (w={box_w}, h={box_h})')

    crop = img.crop(box=[xmin, ymin, xmin + box_w, ymin + box_h])

    if square_crop and (box_w != box_h):
        crop = ImageOps.pad(crop, size=(box_size, box_size), color=0)

    return crop


def pil_to_tensor(img: Image.Image,resize_size=300) -> torch.Tensor:
    """
    Convert PIL image to normalized tensor ready for model input.
    - Use SMALLER ratio so crop fits within image
    - Center crop to target aspect ratio
    - Resize crop to target size

    Args:
        img: PIL Image (already cropped and in RGB)

    Returns:
        Tensor of shape (1, 3, H, W)
    """
    target_w = target_h = resize_size
    w, h = img.size  # PIL: (width, height)

    # Fastai ResizeMethod.Crop uses SMALLER ratio
    # This ensures crop fits within the image bounds
    ratio_w = w / target_w
    ratio_h = h / target_h
    m = min(ratio_w, ratio_h)

    # Crop size that when resized will give target size
    cp_w = int(m * target_w)
    cp_h = int(m * target_h)

    # Center crop position (pcts = 0.5, 0.5 for validation)
    left = (w - cp_w) // 2
    top = (h - cp_h) // 2

    # Crop
    img = img.crop((left, top, left + cp_w, top + cp_h))

    # Resize to target
    img = img.resize((target_w, target_h), Image.BILINEAR)

    # To tensor: HWC uint8 -> NCHW float32 [0, 1]
    img_array = np.asarray(img, dtype=np.float32) / 255.0
    img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0)

    return img_tensor
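As a quick sanity check of the preprocessing path, the sketch below crops a synthetic image with an arbitrary normalized box and converts the crop to a model-ready tensor (the values are made up; 300 matches the default `resize_size`):

```python
# Sketch only: a synthetic image and an arbitrary normalized bounding box.
from PIL import Image
from awc_helpers import crop_image, pil_to_tensor

img = Image.new('RGB', (1920, 1080))                 # stand-in for a camera-trap frame
bbox_norm = (0.25, 0.25, 0.20, 0.10)                 # [x_min, y_min, width, height], normalized 0-1

crop = crop_image(img, bbox_norm, square_crop=True)  # expanded to a square region around the box
tensor = pil_to_tensor(crop, resize_size=300)
print(tensor.shape)                                  # torch.Size([1, 3, 300, 300])
```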
awc_helpers-0.1.1.dist-info/LICENSE
ADDED
@@ -0,0 +1,39 @@
Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License

Copyright (c) 2026 Quan Tran

This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike
4.0 International License. To view a copy of this license, visit:
https://creativecommons.org/licenses/by-nc-sa/4.0/

You are free to:

Share — copy and redistribute the material in any medium or format
Adapt — remix, transform, and build upon the material

Under the following terms:

Attribution — You must give appropriate credit, provide a link to the license,
and indicate if changes were made. You may do so in any reasonable manner, but
not in any way that suggests the licensor endorses you or your use.

NonCommercial — You may not use the material for commercial purposes.

ShareAlike — If you remix, transform, or build upon the material, you must
distribute your contributions under the same license as the original.

No additional restrictions — You may not apply legal terms or technological
measures that legally restrict others from doing anything the license permits.

Notices:

You do not have to comply with the license for elements of the material in the
public domain or where your use is permitted by an applicable exception or
limitation.

No warranties are given. The license may not give you all of the permissions
necessary for your intended use. For example, other rights such as publicity,
privacy, or moral rights may limit how you use the material.

For the full legal code, see:
https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
awc_helpers-0.1.1.dist-info/METADATA
ADDED
@@ -0,0 +1,88 @@
Metadata-Version: 2.1
Name: awc-helpers
Version: 0.1.1
Summary: Australian Wildlife Conservancy's Wildlife detection and species classification inference tools
Author: Quan Tran
License: CC-BY-NC-SA-4.0
Project-URL: Homepage, https://github.com/Australian-Wildlife-Conservancy-AWC/awc_inference
Project-URL: Repository, https://github.com/Australian-Wildlife-Conservancy-AWC/awc_inference
Project-URL: Issues, https://github.com/Australian-Wildlife-Conservancy-AWC/awc_inference/issues
Keywords: wildlife,detection,classification,megadetector,camera-trap,ecology
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Science/Research
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Scientific/Engineering :: Image Recognition
Requires-Python: >=3.9
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch==2.9.1
Requires-Dist: megadetector<=10.0.17
Requires-Dist: ultralytics<=8.4.7
Requires-Dist: timm<=1.0.24

# AWC Helpers

Wildlife detection and species classification inference tools combining MegaDetector with custom species classifiers.

## Installation

### 1. Install PyTorch

**Windows (with CUDA GPU):**
```bash
pip install torch==2.9.1 --index-url https://download.pytorch.org/whl/cu128
```

**Linux / Mac / CPU:**
```bash
pip install torch==2.9.1
```

### 2. Install AWC Helpers

**From PyPI:**
```bash
pip install awc-helpers
```

**From GitHub:**
```bash
pip install git+https://github.com/Australian-Wildlife-Conservancy-AWC/awc_inference.git
```

## Usage

```python
from awc_helpers import DetectAndClassify

# Initialize the pipeline
pipeline = DetectAndClassify(
    detector_path="path/to/megadetector.pt",
    classifier_path="path/to/species_classifier.pth",
    label_names=["species_a", "species_b", "species_c"],
    detection_threshold=0.1,
    clas_threshold=0.5,
)

# Run inference on image paths
results = pipeline.predict(
    inp=["image1.jpg", "image2.jpg"],
    clas_bs=4
)

# Results format: [(identifier, bbox, label, confidence), ...]
for result in results:
    print(result)
```

## License

This project is licensed under the [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License](LICENSE) (CC BY-NC-SA 4.0).

**Non-commercial use only. Derivative works must use the same license.**
awc_helpers-0.1.1.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
awc_helpers/__init__.py,sha256=xStt2_-5rK5MvKtbFjMr6w2GqibT7EZWg7SgXmJ_gS8,474
awc_helpers/awc_inference.py,sha256=x3Lrj7txQ5oaz7rpU_6aKoCIJYcFVyulUH7AMDBDIus,19533
awc_helpers/math_utils.py,sha256=W3G5PGVkiMLM31qjsdJhmJq6KY46XZr6rj1SXWjK4Gk,2656
awc_helpers-0.1.1.dist-info/LICENSE,sha256=-K8JM-Ym5RhIZI1Wh5soC8hkvOhgA4R0fNpbpucPizg,1659
awc_helpers-0.1.1.dist-info/METADATA,sha256=t775UXQUnRNco6chjA8mQ50YcqRoupB-9gwMVxHEyoc,2710
awc_helpers-0.1.1.dist-info/WHEEL,sha256=hPN0AlP2dZM_3ZJZWP4WooepkmU9wzjGgCLCeFjkHLA,92
awc_helpers-0.1.1.dist-info/top_level.txt,sha256=_Xvw_DTZwJ6szEkcXxA1_AGYwKN-Uj6jfz0bF07-S8M,12
awc_helpers-0.1.1.dist-info/RECORD,,
awc_helpers-0.1.1.dist-info/top_level.txt
ADDED
@@ -0,0 +1 @@
awc_helpers