awc-helpers-0.1.2.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
+++ LICENSE
@@ -0,0 +1,39 @@
+ Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License
+
+ Copyright (c) 2026 Quan Tran
+
+ This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike
+ 4.0 International License. To view a copy of this license, visit:
+ https://creativecommons.org/licenses/by-nc-sa/4.0/
+
+ You are free to:
+
+ Share — copy and redistribute the material in any medium or format
+ Adapt — remix, transform, and build upon the material
+
+ Under the following terms:
+
+ Attribution — You must give appropriate credit, provide a link to the license,
+ and indicate if changes were made. You may do so in any reasonable manner, but
+ not in any way that suggests the licensor endorses you or your use.
+
+ NonCommercial — You may not use the material for commercial purposes.
+
+ ShareAlike — If you remix, transform, or build upon the material, you must
+ distribute your contributions under the same license as the original.
+
+ No additional restrictions — You may not apply legal terms or technological
+ measures that legally restrict others from doing anything the license permits.
+
+ Notices:
+
+ You do not have to comply with the license for elements of the material in the
+ public domain or where your use is permitted by an applicable exception or
+ limitation.
+
+ No warranties are given. The license may not give you all of the permissions
+ necessary for your intended use. For example, other rights such as publicity,
+ privacy, or moral rights may limit how you use the material.
+
+ For the full legal code, see:
+ https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
+++ MANIFEST.in
@@ -0,0 +1,3 @@
+ include LICENSE
+ include README.md
+ recursive-include awc_helpers *.py
+++ PKG-INFO
@@ -0,0 +1,103 @@
+ Metadata-Version: 2.1
+ Name: awc_helpers
+ Version: 0.1.2
+ Summary: Australian Wildlife Conservancy's wildlife detection and species classification inference tools
+ Author: Quan Tran
+ License: CC-BY-NC-SA-4.0
+ Project-URL: Homepage, https://github.com/Australian-Wildlife-Conservancy-AWC/awc_inference
+ Project-URL: Repository, https://github.com/Australian-Wildlife-Conservancy-AWC/awc_inference
+ Project-URL: Issues, https://github.com/Australian-Wildlife-Conservancy-AWC/awc_inference/issues
+ Keywords: wildlife,detection,classification,megadetector,camera-trap,ecology
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: Intended Audience :: Science/Research
+ Classifier: Operating System :: OS Independent
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Programming Language :: Python :: 3.13
+ Classifier: Topic :: Scientific/Engineering :: Image Recognition
+ Requires-Python: >=3.9
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: torch==2.9.1
+ Requires-Dist: megadetector<=10.0.17
+ Requires-Dist: ultralytics<=8.4.7
+ Requires-Dist: timm<=1.0.24
+
+ # AWC Helpers
+
+ Wildlife detection and species classification inference tools combining MegaDetector with custom species classifiers.
+
+ ## Installation
+
+ ### 1. Install PyTorch
+
+ **Windows (with CUDA GPU):**
+ ```bash
+ pip install torch==2.9.1 --index-url https://download.pytorch.org/whl/cu128
+ ```
+
+ **Linux / macOS / CPU:**
+ ```bash
+ pip install torch==2.9.1
+ ```
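+
+ To confirm that the installed build can see your GPU (an optional check, not specific to this package):
+
+ ```python
+ import torch
+ print(torch.cuda.is_available())  # True means the CUDA build found a usable GPU
+ ```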
+
+ ### 2. Install AWC Helpers
+
+ **From PyPI:**
+ ```bash
+ pip install awc-helpers
+ ```
+
+ **From GitHub:**
+ ```bash
+ pip install git+https://github.com/Australian-Wildlife-Conservancy-AWC/awc_inference.git
+ ```
+
+ ## Usage
+
+ ```python
+ from awc_helpers import DetectAndClassify
+
+ # Initialize the pipeline
+ pipeline = DetectAndClassify(
+     detector_path="path/to/megadetector.pt",
+     classifier_path="path/to/species_classifier.pth",
+     label_names=["species_a", "species_b", "species_c"],
+     detection_threshold=0.1,
+     clas_threshold=0.5,
+ )
+
+ # Run inference on image paths
+ results = pipeline.predict(
+     inp=["image1.jpg", "image2.jpg"],
+     clas_bs=4
+ )
+
+ # Each result: (identifier, bbox_confidence, bbox, label, probability, ...)
+ for result in results:
+     print(result)
+ ```
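+
+ Each result is a flat tuple: the identifier, then the MegaDetector box confidence and normalized box, then label/probability pairs that met `clas_threshold`. A minimal unpacking sketch (field layout as above; values illustrative):
+
+ ```python
+ for identifier, bbox_conf, bbox, *preds in results:
+     labels, probs = preds[0::2], preds[1::2]
+     print(identifier, bbox, list(zip(labels, probs)))
+ ```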
+
+ ## License
+
+ This project is licensed under the [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License](LICENSE) (CC BY-NC-SA 4.0).
+
+ **Non-commercial use only. Derivative works must use the same license.**
+++ README.md
@@ -0,0 +1,75 @@
+ # AWC Helpers
+
+ Wildlife detection and species classification inference tools combining MegaDetector with custom species classifiers.
+
+ ## Installation
+
+ ### 1. Install PyTorch
+
+ **Windows (with CUDA GPU):**
+ ```bash
+ pip install torch==2.9.1 --index-url https://download.pytorch.org/whl/cu128
+ ```
+
+ **Linux / macOS / CPU:**
+ ```bash
+ pip install torch==2.9.1
+ ```
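+
+ To confirm that the installed build can see your GPU (an optional check, not specific to this package):
+
+ ```python
+ import torch
+ print(torch.cuda.is_available())  # True means the CUDA build found a usable GPU
+ ```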
+
+ ### 2. Install AWC Helpers
+
+ **From PyPI:**
+ ```bash
+ pip install awc-helpers
+ ```
+
+ **From GitHub:**
+ ```bash
+ pip install git+https://github.com/Australian-Wildlife-Conservancy-AWC/awc_inference.git
+ ```
+
+ ## Usage
+
+ ```python
+ from awc_helpers import DetectAndClassify
+
+ # Initialize the pipeline
+ pipeline = DetectAndClassify(
+     detector_path="path/to/megadetector.pt",
+     classifier_path="path/to/species_classifier.pth",
+     label_names=["species_a", "species_b", "species_c"],
+     detection_threshold=0.1,
+     clas_threshold=0.5,
+ )
+
+ # Run inference on image paths
+ results = pipeline.predict(
+     inp=["image1.jpg", "image2.jpg"],
+     clas_bs=4
+ )
+
+ # Each result: (identifier, bbox_confidence, bbox, label, probability, ...)
+ for result in results:
+     print(result)
+ ```
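+
+ Each result is a flat tuple: the identifier, then the MegaDetector box confidence and normalized box, then label/probability pairs that met `clas_threshold`. A minimal unpacking sketch (field layout as above; values illustrative):
+
+ ```python
+ for identifier, bbox_conf, bbox, *preds in results:
+     labels, probs = preds[0::2], preds[1::2]
+     print(identifier, bbox, list(zip(labels, probs)))
+ ```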
+
+ ## License
+
+ This project is licensed under the [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License](LICENSE) (CC BY-NC-SA 4.0).
+
+ **Non-commercial use only. Derivative works must use the same license.**
+++ awc_helpers/__init__.py
@@ -0,0 +1,20 @@
+ """AWC Helpers - Wildlife detection and classification inference tools."""
+
+ from .awc_inference import (
+     DetectAndClassify,
+     SpeciesClasInference,
+     format_md_detections,
+     load_classification_model,
+ )
+ from .math_utils import crop_image, pil_to_tensor
+
+ __version__ = "0.1.2"
+
+ __all__ = [
+     "DetectAndClassify",
+     "SpeciesClasInference",
+     "format_md_detections",
+     "load_classification_model",
+     "crop_image",
+     "pil_to_tensor",
+ ]
+++ awc_helpers/awc_inference.py
@@ -0,0 +1,459 @@
+
+ """
+ Author: Quan Tran (Australian Wildlife Conservancy)
+ Wildlife species detection and classification inference module.
+
+ This module provides classes and functions for running inference pipelines
+ that combine MegaDetector-based animal detection with species classification
+ using fine-tuned image classification models (from the timm library).
+
+ Classes:
+     SpeciesClasInference: Run species classification on pre-detected animal crops.
+     DetectAndClassify: End-to-end pipeline combining detection and classification.
+
+ Functions:
+     format_md_detections: Format MegaDetector outputs for classification input.
+     load_classification_model: Load a timm-based classification model.
+ """
+
+ from zoneinfo import ZoneInfo
+ import datetime
+ import timm
+ import torch
+ import numpy as np
+ from pathlib import Path
+ from megadetector.detection import run_detector
+ from typing import List, Tuple, Union
+ from PIL import Image
+ from .math_utils import crop_image, pil_to_tensor
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+
+
+ def format_md_detections(md_result: dict,
+                          filter_category: str = 'animal',
+                          for_clas: bool = True) -> List:
+     """
+     Format MegaDetector outputs for classification input or other uses.
+
+     Args:
+         md_result: Dictionary containing MegaDetector detection results with keys
+             'file', 'detections', and optionally 'PIL' for in-memory images.
+         filter_category: Category to filter detections by (e.g., 'animal', 'person').
+             If None or empty, all detections are included.
+         for_clas: If True, format output for classification pipeline input.
+             Otherwise, format output with full detection metadata.
+
+     Returns:
+         List of formatted detection results. Format depends on `for_clas`:
+         - If for_clas=True: List of (img_path, bbox_confidence, bbox) or (PIL, identifier, bbox_confidence, bbox) tuples
+         - If for_clas=False: List of [file, category, bbox_confidence, bbox] lists
+     """
+     md_animal_id = next((k for k, v in run_detector.DEFAULT_DETECTOR_LABEL_MAP.items() if v == filter_category), None)
+     results = []
+     img_file = md_result['file']
+     if md_result.get('detections'):
+         for _d in md_result['detections']:
+             if not filter_category or _d['category'] == md_animal_id:
+                 if for_clas:
+                     if 'PIL' in md_result:
+                         results.append((md_result['PIL'], img_file, _d['conf'], tuple(_d['bbox'])))
+                     else:
+                         results.append((Path(img_file).as_posix(), _d['conf'], tuple(_d['bbox'])))
+                 else:
+                     results.append([Path(img_file).as_posix(), _d['category'], _d['conf'], tuple(_d['bbox'])])
+     return results
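+
+ # Illustrative for_clas=True output for one animal detection with a path input
+ # (hypothetical values): [('images/IMG_0001.JPG', 0.92, (0.1, 0.2, 0.3, 0.4))]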
+
+ def load_classification_model(
+     finetuned_model: str = None,
+     classification_model: str = 'tf_efficientnet_b5.ns_jft_in1k',
+     label_info: Union[List[str], int] = None
+ ):
+     """
+     Load a timm-based image classification model.
+
+     Creates a classification model using the timm library, optionally loading
+     fine-tuned weights from a checkpoint file.
+
+     Args:
+         finetuned_model: Path to fine-tuned model weights (.pth file).
+             If None, loads pretrained ImageNet weights.
+         classification_model: Name of the timm model architecture.
+             Hyphens are automatically converted to underscores.
+         label_info: Either a list of class label names, or an integer
+             specifying the number of output classes.
+
+     Returns:
+         torch.nn.Module: The loaded classification model.
+
+     Raises:
+         FileNotFoundError: If finetuned_model path does not exist.
+     """
+     # Convert model name format for timm
+     timm_model_name = classification_model.replace('-', '_')
+     num_classes = label_info if isinstance(label_info, int) else len(label_info)
+     if finetuned_model is not None:
+         # Create model with timm (without pretrained weights)
+         model = timm.create_model(timm_model_name, pretrained=False, num_classes=num_classes)
+         # Load fine-tuned weights
+         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         state_dict = torch.load(finetuned_model, map_location=device)
+         ret = model.load_state_dict(state_dict, strict=False)
+         if len(ret.missing_keys):
+             logger.warning(f'Missing weights: {ret.missing_keys}')
+         if len(ret.unexpected_keys):
+             logger.warning(f'Unexpected weights: {ret.unexpected_keys}')
+         logger.info(f'Loaded finetuned timm classification model: {Path(finetuned_model).name} with {num_classes} classes')
+     else:
+         model = timm.create_model(timm_model_name, pretrained=True, num_classes=num_classes)
+         logger.info(f'Loaded pretrained timm classification model: {timm_model_name} with {num_classes} classes')
+     return model
+
+ class SpeciesClasInference:
+     """
+     Species classification inference engine for wildlife images.
+
+     This class handles loading a classification model and running inference
+     on cropped animal detections. It supports batch processing, GPU acceleration,
+     and mixed-precision inference.
+
+     Attributes:
+         device: PyTorch device (cuda or cpu) for model inference.
+         model: The loaded classification model.
+         label_names: List of class label names.
+         clas_threshold: Minimum confidence threshold for predictions.
+         pred_topn: Number of top predictions to return per image.
+         prob_round: Decimal places to round probabilities.
+         use_fp16: Whether to use FP16 mixed-precision inference.
+         resize_size: Target size for image resizing before inference.
+         skip_errors: Whether to skip images that fail to process.
+
+     Example:
+         >>> classifier = SpeciesClasInference(
+         ...     classifier_path='model.pth',
+         ...     classifier_base='tf_efficientnet_b5.ns_jft_in1k',
+         ...     label_names=['cat', 'dog', 'bird']
+         ... )
+         >>> results = classifier.predict_batch([(image_path, bbox_conf, bbox)])
+     """
+
+     def __init__(self,
+                  classifier_path: str,
+                  classifier_base: str,
+                  label_names: List[str] = None,
+                  pred_topn: int = 1,
+                  prob_round: int = 4,
+                  clas_threshold: float = 0.5,
+                  resize_size: int = 300,
+                  force_cpu: bool = False,
+                  use_fp16: bool = False,
+                  skip_errors: bool = True):
+         """
+         Initialize the species classification inference engine.
+         """
+         if torch.cuda.is_available() and not force_cpu:
+             self.device = torch.device('cuda')
+             logger.info(f"\tGPU Device: {torch.cuda.get_device_name()}")
+         else:
+             self.device = torch.device('cpu')
+
+         self.model = load_classification_model(finetuned_model=classifier_path,
+                                                classification_model=classifier_base,
+                                                label_info=label_names)
+         self.label_names = label_names
+         self.model = self.model.to(self.device)
+         self.model.eval()
+
+         self.clas_threshold = clas_threshold
+         self.pred_topn = pred_topn
+         self.prob_round = prob_round
+         self.use_fp16 = use_fp16 and self.device.type == 'cuda'
+         self.resize_size = resize_size
+         self.skip_errors = skip_errors
+
+     def _prepare_crop(
+         self,
+         source: Union[str, Image.Image],
+         bbox_norm: Tuple[float, float, float, float]
+     ) -> Image.Image:
+         """
+         Load (if path) and crop image to bounding box.
+
+         Args:
+             source: Image path string OR PIL Image
+             bbox_norm: Normalized bounding box (x_min, y_min, width, height)
+
+         Returns:
+             Cropped RGB PIL Image
+         """
+         if isinstance(source, str):
+             with Image.open(source) as img:
+                 img.load()
+             img = img.convert('RGB') if img.mode != 'RGB' else img
+             return crop_image(img, bbox_norm, square_crop=True)
+         else:
+             img = source.convert('RGB') if source.mode != 'RGB' else source
+             return crop_image(img, bbox_norm, square_crop=True)
+
+     def _predict(self, input_tensor: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]:
+         """
+         Run the classification model on an input tensor.
+
+         Args:
+             input_tensor: Tensor of shape (B, 3, H, W)
+
+         Returns:
+             Tuple of
+             - Top-k probabilities array of shape (B, pred_topn)
+             - Top-k class indices array of shape (B, pred_topn)
+         """
+         with torch.no_grad():
+             if self.use_fp16:
+                 with torch.amp.autocast('cuda'):
+                     logits = self.model(input_tensor)
+             else:
+                 logits = self.model(input_tensor)
+
+         # Softmax in fp32 for numerical stability
+         probs = torch.nn.functional.softmax(logits.float(), dim=1)
+
+         top_probs, top_indices = torch.topk(probs, k=self.pred_topn, dim=1)
+         return (top_probs.cpu().numpy().round(self.prob_round),
+                 top_indices.cpu().numpy())
+
+     def _format_output(self,
+                        identifier: Union[str, None],
+                        bbox_conf: float,
+                        bbox_norm: Tuple[float, float, float, float],
+                        top_probs: np.ndarray,
+                        top_indices: np.ndarray):
+         """
+         Build output result tuple from predictions.
+
+         Args:
+             identifier: Image path or custom ID
+             bbox_conf: Confidence score of the bounding box
+             bbox_norm: Bounding box used
+             top_probs: Top-k probabilities (1D array)
+             top_indices: Top-k class indices (1D array)
+
+         Returns:
+             Tuple: (identifier, bbox_conf, bbox, label1, prob1, label2, prob2, ...)
+         """
+         result = [identifier, bbox_conf, bbox_norm]
+         for k in range(len(top_indices)):
+             label = self.label_names[top_indices[k]]
+             prob = round(float(top_probs[k]), self.prob_round)
+             if prob >= self.clas_threshold:
+                 result.extend([label, prob])
+
+         return tuple(result)
+
+     def predict_batch(
+         self,
+         inputs: List[Union[Tuple[str, float, Tuple[float, float, float, float]], Tuple[Image.Image, str, float, Tuple[float, float, float, float]]]],
+         batch_size: int = 1,
+     ) -> List[Tuple]:
+         """
+         Run inference on a batch of inputs.
+
+         Args:
+             inputs: List of (img_path, bbox_confidence, bbox) tuples, or (PIL Image, id, bbox_confidence, bbox) tuples for streaming
+             batch_size: Number of images to process at once
+
+         Returns:
+             List of result tuples
+         """
+         results = []
+
+         for batch_start in range(0, len(inputs), batch_size):
+             batch_inputs = inputs[batch_start:batch_start + batch_size]
+
+             # Preprocess batch
+             batch_tensors = []
+             batch_metadata = []  # (identifier, bbox_conf, bbox)
+
+             for *sources, bbox in batch_inputs:
+                 try:
+                     # Path inputs: (path, conf, bbox); streaming inputs: (PIL, id, conf, bbox)
+                     identifier = sources[0] if len(sources) == 2 else sources[1]
+                     bbox_conf = round(sources[-1], self.prob_round)
+                     cropped = self._prepare_crop(sources[0], bbox)
+                     tensor = pil_to_tensor(cropped, resize_size=self.resize_size).to(self.device)
+                     batch_tensors.append(tensor)
+                     batch_metadata.append((identifier, bbox_conf, bbox))
+                 except Exception as e:
+                     if self.skip_errors:
+                         logger.warning(f"Failed to process {sources[0]!r}: {e}")
+                         continue
+                     raise
+
+             if not batch_tensors:
+                 continue
+
+             # Stack and run inference
+             batch_tensor = torch.cat(batch_tensors, dim=0)
+             top_probs, top_indices = self._predict(batch_tensor)
+
+             for i, (identifier, bbox_conf, bbox) in enumerate(batch_metadata):
+                 result = self._format_output(
+                     identifier, bbox_conf, bbox,
+                     top_probs[i], top_indices[i],
+                 )
+                 results.append(result)
+
+         return results
+
+
+ class DetectAndClassify:
+     """
+     End-to-end wildlife detection and classification pipeline.
+
+     Combines MegaDetector for animal detection with a species classifier
+     to provide a complete inference pipeline from raw images to species
+     predictions.
+
+     Attributes:
+         md_detector: MegaDetector model instance for animal detection.
+         clas_inference: SpeciesClasInference instance for classification.
+         detection_threshold: Minimum confidence for detection filtering.
+
+     Example:
+         >>> pipeline = DetectAndClassify(
+         ...     detector_path='md_v5a.0.0.pt',
+         ...     classifier_path='species_model.pth',
+         ...     label_names=['kangaroo', 'wallaby', 'wombat']
+         ... )
+         >>> results = pipeline.predict('wildlife_image.jpg')
+     """
+
+     def __init__(self,
+                  detector_path: str,
+                  classifier_path: str,
+                  label_names: List[str],
+                  classifier_base: str = 'tf_efficientnet_b5.ns_jft_in1k',
+                  detection_threshold: float = 0.1,
+                  clas_threshold: float = 0.5,
+                  pred_topn: int = 1,
+                  resize_size: int = 300,
+                  force_cpu: bool = False,
+                  skip_clas_errors: bool = True):
+         """
+         Initialize the detection and classification pipeline.
+
+         Args:
+             detector_path: Path to the MegaDetector model weights.
+             classifier_path: Path to the species classifier weights.
+             label_names: List of species class names.
+             classifier_base: Name of the base timm model architecture.
+             detection_threshold: Minimum confidence for animal detections.
+             clas_threshold: Minimum confidence for classification predictions.
+             pred_topn: Number of top classification predictions to return.
+             resize_size: Target image size for classification model input.
+             force_cpu: If True, use CPU even if CUDA is available.
+             skip_clas_errors: If True, skip classification errors instead of raising.
+         """
+         self.md_detector = run_detector.load_detector(str(detector_path),
+                                                       force_cpu=force_cpu)
+         self.clas_inference = SpeciesClasInference(classifier_path=classifier_path,
+                                                    classifier_base=classifier_base,
+                                                    clas_threshold=clas_threshold,
+                                                    label_names=label_names,
+                                                    pred_topn=pred_topn,
+                                                    resize_size=resize_size,
+                                                    force_cpu=force_cpu,
+                                                    skip_errors=skip_clas_errors)
+         self.detection_threshold = detection_threshold
+
+     def _validate_input(
+         self,
+         inp: Union[str, Image.Image, List[Union[str, Image.Image]]],
+         identifier: Union[str, List[str], None]
+     ) -> Tuple[List, List]:
+         """
+         Validate and normalize input images and identifiers.
+
+         Args:
+             inp: Single image or list of images (paths or PIL Images).
+             identifier: Optional identifier(s) for the images. If None,
+                 uses file paths for string inputs or timestamps for PIL images.
+
+         Returns:
+             Tuple of (normalized_inputs, normalized_identifiers)
+
+         Raises:
+             AssertionError: If identifier list length doesn't match input list length.
+         """
+         if not isinstance(inp, (list, tuple)):
+             inp = [inp]
+         elif len(inp) == 0:
+             return [], []
+         if identifier is None:
+             if isinstance(inp[0], str):
+                 identifier = inp
+             else:
+                 # Human-readable timestamp identifier (Australia/Perth local time)
+                 now = datetime.datetime.now(ZoneInfo("Australia/Perth"))
+                 now_str = now.strftime("%Y%m%d_%H%M%S") + f"_{now.microsecond // 1000:03d}"
+                 identifier = [now_str] if len(inp) == 1 else [f'{now_str}_{i+1}' for i in range(len(inp))]
+
+         elif not isinstance(identifier, (list, tuple)):
+             identifier = [identifier]
+
+         assert len(identifier) == len(inp), "Length of identifier list (containing e.g. image names) must match length of input list."
+         return inp, identifier
+
+     def predict(
+         self,
+         inp: Union[str, Image.Image, List[Union[str, Image.Image]]],
+         identifier: Union[str, List[str], None] = None,
+         clas_bs: int = 4
+     ) -> List[Tuple]:
+         """
+         Run detection and classification on input images.
+
+         Processes images through the MegaDetector to find animals, then
+         classifies each detected animal using the species classifier.
+
+         Args:
+             inp: Single image or list of images. Can be file paths (str)
+                 or PIL Image objects.
+             identifier: Optional identifier(s) for tracking results back to
+                 source images. If None, uses file paths or timestamps.
+             clas_bs: Batch size for classification inference.
+
+         Returns:
+             List of result tuples, one per detected animal. Each tuple contains:
+             (identifier, bbox_confidence, bbox, label1, prob1, ...) where the
+             number of label/prob pairs depends on pred_topn and clas_threshold.
+         """
+         inp, identifier = self._validate_input(inp, identifier)
+         if len(inp) == 0:
+             return []
+
+         md_results = []
+         for item, img_id in zip(inp, identifier):
+             img = item
+             if isinstance(item, str):
+                 img = Image.open(item)
+             try:
+                 md_result = self.md_detector.generate_detections_one_image(img, img_id,
+                                                                            detection_threshold=self.detection_threshold)
+                 if not isinstance(item, str):
+                     md_result['PIL'] = img
+                 md_results.extend(format_md_detections(md_result))
+             finally:
+                 if isinstance(item, str):
+                     img.close()
+
+         return self.clas_inference.predict_batch(md_results, batch_size=clas_bs)
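+
+ # Illustrative streaming usage with in-memory frames (identifiers are placeholders,
+ # not shipped with this package):
+ #     pipeline.predict([Image.open('a.jpg'), Image.open('b.jpg')],
+ #                      identifier=['cam1_a', 'cam1_b'], clas_bs=8)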
+++ awc_helpers/math_utils.py
@@ -0,0 +1,90 @@
+ from PIL import Image, ImageOps
+ from typing import Sequence
+ import torch
+ import numpy as np
+
+ def crop_image(img: Image.Image, bbox_norm: Sequence[float], square_crop: bool = True) -> Image.Image:
+     """
+     Crop image based on normalized bounding box (MegaDetector format).
+
+     Args:
+         img: PIL Image to crop
+         bbox_norm: [x_min, y_min, width, height], all normalized 0-1
+         square_crop: Whether to make the crop square with padding
+
+     Returns:
+         Cropped PIL Image
+     """
+     img_w, img_h = img.size
+     xmin = int(bbox_norm[0] * img_w)
+     ymin = int(bbox_norm[1] * img_h)
+     box_w = int(bbox_norm[2] * img_w)
+     box_h = int(bbox_norm[3] * img_h)
+
+     if square_crop:
+         box_size = max(box_w, box_h)
+         xmin = max(0, min(xmin - int((box_size - box_w) / 2), img_w - box_w))
+         ymin = max(0, min(ymin - int((box_size - box_h) / 2), img_h - box_h))
+         box_w = min(img_w, box_size)
+         box_h = min(img_h, box_size)
+
+     if box_w == 0 or box_h == 0:
+         raise ValueError(f'Invalid crop dimensions (w={box_w}, h={box_h})')
+
+     crop = img.crop(box=[xmin, ymin, xmin + box_w, ymin + box_h])
+
+     if square_crop and (box_w != box_h):
+         crop = ImageOps.pad(crop, size=(box_size, box_size), color=0)
+
+     return crop
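+
+ # Worked example: on a 1000x800 image, bbox_norm=(0.375, 0.4, 0.25, 0.2) selects a
+ # 250x160 box, which square_crop expands to a 250x250 crop (zero-padded if clipped).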
+
+
+ def pil_to_tensor(img: Image.Image, resize_size: int = 300) -> torch.Tensor:
+     """
+     Convert a PIL image to a normalized tensor ready for model input.
+     - Use the SMALLER ratio so the crop fits within the image
+     - Center crop to the target aspect ratio
+     - Resize the crop to the target size
+
+     Args:
+         img: PIL Image (already cropped and in RGB)
+         resize_size: Target side length of the square output, in pixels
+
+     Returns:
+         Tensor of shape (1, 3, H, W)
+     """
+     target_w = target_h = resize_size
+     w, h = img.size  # PIL: (width, height)
+
+     # Fastai ResizeMethod.Crop uses the SMALLER ratio
+     # This ensures the crop fits within the image bounds
+     ratio_w = w / target_w
+     ratio_h = h / target_h
+     m = min(ratio_w, ratio_h)
+
+     # Crop size that, when resized, will give the target size
+     cp_w = int(m * target_w)
+     cp_h = int(m * target_h)
+
+     # Center crop position (pcts = 0.5, 0.5 for validation)
+     left = (w - cp_w) // 2
+     top = (h - cp_h) // 2
+
+     # Crop
+     img = img.crop((left, top, left + cp_w, top + cp_h))
+
+     # Resize to target
+     img = img.resize((target_w, target_h), Image.BILINEAR)
+
+     # To tensor: HWC uint8 -> NCHW float32 [0, 1]
+     img_array = np.asarray(img, dtype=np.float32) / 255.0
+     img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0)
+
+     return img_tensor
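+
+ # Illustrative shape check (assumes a 640x480 RGB image):
+ #     pil_to_tensor(Image.new('RGB', (640, 480)), resize_size=300).shape
+ #     -> torch.Size([1, 3, 300, 300])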
+++ awc_helpers.egg-info/PKG-INFO
@@ -0,0 +1,103 @@
+ Metadata-Version: 2.1
+ Name: awc-helpers
+ Version: 0.1.2
+ Summary: Australian Wildlife Conservancy's wildlife detection and species classification inference tools
+ Author: Quan Tran
+ License: CC-BY-NC-SA-4.0
+ Project-URL: Homepage, https://github.com/Australian-Wildlife-Conservancy-AWC/awc_inference
+ Project-URL: Repository, https://github.com/Australian-Wildlife-Conservancy-AWC/awc_inference
+ Project-URL: Issues, https://github.com/Australian-Wildlife-Conservancy-AWC/awc_inference/issues
+ Keywords: wildlife,detection,classification,megadetector,camera-trap,ecology
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: Intended Audience :: Science/Research
+ Classifier: Operating System :: OS Independent
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Programming Language :: Python :: 3.13
+ Classifier: Topic :: Scientific/Engineering :: Image Recognition
+ Requires-Python: >=3.9
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: torch==2.9.1
+ Requires-Dist: megadetector<=10.0.17
+ Requires-Dist: ultralytics<=8.4.7
+ Requires-Dist: timm<=1.0.24
+
+ # AWC Helpers
+
+ Wildlife detection and species classification inference tools combining MegaDetector with custom species classifiers.
+
+ ## Installation
+
+ ### 1. Install PyTorch
+
+ **Windows (with CUDA GPU):**
+ ```bash
+ pip install torch==2.9.1 --index-url https://download.pytorch.org/whl/cu128
+ ```
+
+ **Linux / macOS / CPU:**
+ ```bash
+ pip install torch==2.9.1
+ ```
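+
+ To confirm that the installed build can see your GPU (an optional check, not specific to this package):
+
+ ```python
+ import torch
+ print(torch.cuda.is_available())  # True means the CUDA build found a usable GPU
+ ```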
+
+ ### 2. Install AWC Helpers
+
+ **From PyPI:**
+ ```bash
+ pip install awc-helpers
+ ```
+
+ **From GitHub:**
+ ```bash
+ pip install git+https://github.com/Australian-Wildlife-Conservancy-AWC/awc_inference.git
+ ```
+
+ ## Usage
+
+ ```python
+ from awc_helpers import DetectAndClassify
+
+ # Initialize the pipeline
+ pipeline = DetectAndClassify(
+     detector_path="path/to/megadetector.pt",
+     classifier_path="path/to/species_classifier.pth",
+     label_names=["species_a", "species_b", "species_c"],
+     detection_threshold=0.1,
+     clas_threshold=0.5,
+ )
+
+ # Run inference on image paths
+ results = pipeline.predict(
+     inp=["image1.jpg", "image2.jpg"],
+     clas_bs=4
+ )
+
+ # Each result: (identifier, bbox_confidence, bbox, label, probability, ...)
+ for result in results:
+     print(result)
+ ```
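+
+ Each result is a flat tuple: the identifier, then the MegaDetector box confidence and normalized box, then label/probability pairs that met `clas_threshold`. A minimal unpacking sketch (field layout as above; values illustrative):
+
+ ```python
+ for identifier, bbox_conf, bbox, *preds in results:
+     labels, probs = preds[0::2], preds[1::2]
+     print(identifier, bbox, list(zip(labels, probs)))
+ ```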
+
+ ## License
+
+ This project is licensed under the [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License](LICENSE) (CC BY-NC-SA 4.0).
+
+ **Non-commercial use only. Derivative works must use the same license.**
+++ awc_helpers.egg-info/SOURCES.txt
@@ -0,0 +1,12 @@
+ LICENSE
+ MANIFEST.in
+ README.md
+ pyproject.toml
+ awc_helpers/__init__.py
+ awc_helpers/awc_inference.py
+ awc_helpers/math_utils.py
+ awc_helpers.egg-info/PKG-INFO
+ awc_helpers.egg-info/SOURCES.txt
+ awc_helpers.egg-info/dependency_links.txt
+ awc_helpers.egg-info/requires.txt
+ awc_helpers.egg-info/top_level.txt
+++ awc_helpers.egg-info/requires.txt
@@ -0,0 +1,4 @@
+ torch==2.9.1
+ megadetector<=10.0.17
+ ultralytics<=8.4.7
+ timm<=1.0.24
+++ awc_helpers.egg-info/top_level.txt
@@ -0,0 +1 @@
+ awc_helpers
+++ pyproject.toml
@@ -0,0 +1,38 @@
+ [build-system]
+ requires = ["setuptools>=42,<69", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "awc_helpers"
+ version = "0.1.2"
+ description = "Australian Wildlife Conservancy's wildlife detection and species classification inference tools"
+ readme = "README.md"
+ license = {text = "CC-BY-NC-SA-4.0"}
+ requires-python = ">=3.9"
+ authors = [
+     { name = "Quan Tran" }
+ ]
+ keywords = ["wildlife", "detection", "classification", "megadetector", "camera-trap", "ecology"]
+ classifiers = [
+     "Development Status :: 3 - Alpha",
+     "Intended Audience :: Science/Research",
+     "Operating System :: OS Independent",
+     "Programming Language :: Python :: 3",
+     "Programming Language :: Python :: 3.9",
+     "Programming Language :: Python :: 3.10",
+     "Programming Language :: Python :: 3.11",
+     "Programming Language :: Python :: 3.12",
+     "Programming Language :: Python :: 3.13",
+     "Topic :: Scientific/Engineering :: Image Recognition",
+ ]
+ dependencies = [
+     "torch==2.9.1",
+     "megadetector<=10.0.17",
+     "ultralytics<=8.4.7",
+     "timm<=1.0.24"
+ ]
+
+ [project.urls]
+ Homepage = "https://github.com/Australian-Wildlife-Conservancy-AWC/awc_inference"
+ Repository = "https://github.com/Australian-Wildlife-Conservancy-AWC/awc_inference"
+ Issues = "https://github.com/Australian-Wildlife-Conservancy-AWC/awc_inference/issues"
+++ setup.cfg
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+