birder 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/__init__.py +2 -0
- birder/common/fs_ops.py +81 -1
- birder/common/training_cli.py +6 -1
- birder/common/training_utils.py +4 -0
- birder/data/collators/detection.py +3 -1
- birder/datahub/_lib.py +15 -6
- birder/datahub/evaluation.py +591 -0
- birder/eval/__init__.py +0 -0
- birder/eval/__main__.py +74 -0
- birder/eval/_embeddings.py +50 -0
- birder/eval/adversarial.py +315 -0
- birder/eval/benchmarks/__init__.py +0 -0
- birder/eval/benchmarks/awa2.py +357 -0
- birder/eval/benchmarks/bioscan5m.py +198 -0
- birder/eval/benchmarks/fishnet.py +318 -0
- birder/eval/benchmarks/flowers102.py +210 -0
- birder/eval/benchmarks/fungiclef.py +261 -0
- birder/eval/benchmarks/nabirds.py +202 -0
- birder/eval/benchmarks/newt.py +262 -0
- birder/eval/benchmarks/plankton.py +255 -0
- birder/eval/benchmarks/plantdoc.py +259 -0
- birder/eval/benchmarks/plantnet.py +252 -0
- birder/eval/classification.py +235 -0
- birder/eval/methods/__init__.py +0 -0
- birder/eval/methods/ami.py +78 -0
- birder/eval/methods/knn.py +71 -0
- birder/eval/methods/linear.py +152 -0
- birder/eval/methods/mlp.py +178 -0
- birder/eval/methods/simpleshot.py +100 -0
- birder/eval/methods/svm.py +92 -0
- birder/inference/classification.py +23 -2
- birder/inference/detection.py +35 -15
- birder/net/cswin_transformer.py +2 -1
- birder/net/detection/base.py +41 -18
- birder/net/detection/deformable_detr.py +63 -39
- birder/net/detection/detr.py +23 -20
- birder/net/detection/efficientdet.py +42 -25
- birder/net/detection/faster_rcnn.py +53 -21
- birder/net/detection/fcos.py +42 -23
- birder/net/detection/lw_detr.py +58 -35
- birder/net/detection/plain_detr.py +54 -43
- birder/net/detection/retinanet.py +46 -34
- birder/net/detection/rt_detr_v1.py +41 -38
- birder/net/detection/rt_detr_v2.py +50 -40
- birder/net/detection/ssd.py +47 -31
- birder/net/detection/yolo_v2.py +33 -18
- birder/net/detection/yolo_v3.py +35 -33
- birder/net/detection/yolo_v4.py +35 -20
- birder/net/detection/yolo_v4_tiny.py +1 -2
- birder/net/hiera.py +44 -67
- birder/net/maxvit.py +2 -2
- birder/net/mim/fcmae.py +2 -2
- birder/net/mim/mae_hiera.py +9 -16
- birder/net/nextvit.py +4 -4
- birder/net/rope_deit3.py +1 -1
- birder/net/rope_flexivit.py +1 -1
- birder/net/rope_vit.py +1 -1
- birder/net/squeezenet.py +1 -1
- birder/net/ssl/capi.py +32 -25
- birder/net/ssl/dino_v2.py +12 -15
- birder/net/ssl/franca.py +26 -19
- birder/net/van.py +2 -2
- birder/net/xcit.py +1 -1
- birder/ops/msda.py +46 -16
- birder/scripts/benchmark.py +35 -8
- birder/scripts/predict.py +14 -1
- birder/scripts/predict_detection.py +7 -1
- birder/scripts/train.py +15 -3
- birder/scripts/train_detection.py +16 -6
- birder/scripts/train_franca.py +10 -2
- birder/scripts/train_kd.py +16 -3
- birder/tools/adversarial.py +5 -0
- birder/tools/convert_model.py +101 -43
- birder/tools/quantize_model.py +33 -16
- birder/version.py +1 -1
- {birder-0.4.2.dist-info → birder-0.4.4.dist-info}/METADATA +16 -9
- {birder-0.4.2.dist-info → birder-0.4.4.dist-info}/RECORD +81 -58
- birder/scripts/evaluate.py +0 -176
- {birder-0.4.2.dist-info → birder-0.4.4.dist-info}/WHEEL +0 -0
- {birder-0.4.2.dist-info → birder-0.4.4.dist-info}/entry_points.txt +0 -0
- {birder-0.4.2.dist-info → birder-0.4.4.dist-info}/licenses/LICENSE +0 -0
- {birder-0.4.2.dist-info → birder-0.4.4.dist-info}/top_level.txt +0 -0
birder/__init__.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from birder.common.fs_ops import load_model_with_cfg
|
|
2
2
|
from birder.common.fs_ops import load_pretrained_model
|
|
3
|
+
from birder.common.fs_ops import load_pretrained_model_and_transform
|
|
3
4
|
from birder.common.lib import get_channels_from_signature
|
|
4
5
|
from birder.common.lib import get_size_from_signature
|
|
5
6
|
from birder.data.transforms.classification import inference_preset as classification_transform
|
|
@@ -17,5 +18,6 @@ __all__ = [
|
|
|
17
18
|
"list_pretrained_models",
|
|
18
19
|
"load_model_with_cfg",
|
|
19
20
|
"load_pretrained_model",
|
|
21
|
+
"load_pretrained_model_and_transform",
|
|
20
22
|
"__version__",
|
|
21
23
|
]
|
birder/common/fs_ops.py
CHANGED
|
@@ -2,6 +2,7 @@ import json
|
|
|
2
2
|
import logging
|
|
3
3
|
import os
|
|
4
4
|
import re
|
|
5
|
+
from collections.abc import Callable
|
|
5
6
|
from collections.abc import Iterator
|
|
6
7
|
from pathlib import Path
|
|
7
8
|
from typing import Any
|
|
@@ -24,6 +25,8 @@ from birder.common.lib import get_network_name
|
|
|
24
25
|
from birder.common.lib import get_pretrained_model_url
|
|
25
26
|
from birder.conf import settings
|
|
26
27
|
from birder.data.transforms.classification import RGBType
|
|
28
|
+
from birder.data.transforms.classification import inference_preset
|
|
29
|
+
from birder.data.transforms.detection import InferenceTransform
|
|
27
30
|
from birder.model_registry import Task
|
|
28
31
|
from birder.model_registry import registry
|
|
29
32
|
from birder.model_registry.manifest import FileFormatType
|
|
@@ -801,7 +804,8 @@ def load_detection_model(
|
|
|
801
804
|
for param in net.parameters():
|
|
802
805
|
param.requires_grad_(False)
|
|
803
806
|
|
|
804
|
-
|
|
807
|
+
if pt2 is False: # NOTE: Remove when GraphModule add support for 'eval'
|
|
808
|
+
net.eval()
|
|
805
809
|
|
|
806
810
|
if len(backbone_loaded_config) == 0:
|
|
807
811
|
backbone_custom_config = None
|
|
@@ -918,6 +922,82 @@ def load_pretrained_model(
|
|
|
918
922
|
raise ValueError(f"Unknown model type: {model_metadata['task']}")
|
|
919
923
|
|
|
920
924
|
|
|
925
|
+
def load_pretrained_model_and_transform(
    weights: str,
    *,
    dst: Optional[str | Path] = None,
    file_format: FileFormatType = "pt",
    inference: bool = True,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    custom_config: Optional[dict[str, Any]] = None,
    progress_bar: bool = True,
    classification_kwargs: Optional[dict[str, Any]] = None,
    detection_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[BaseNet | DetectionBaseNet, ModelInfo | DetectionModelInfo, Callable[..., torch.Tensor]]:
    """
    Load a pre-trained model together with the inference transform that matches it

    Thin wrapper around load_pretrained_model for the common case where the caller
    needs both the network and its default preprocessing. The transform is chosen by
    the type of the returned model info: DetectionModelInfo gets InferenceTransform,
    anything else gets inference_preset.

    Parameters
    ----------
    weights
        Name of the pre-trained weights in the model registry.
    dst
        Where the weights are downloaded to / loaded from.
    file_format
        Model file format (e.g. pt, pt2, safetensors, etc.)
    inference
        Whether to prepare the model for inference.
    device
        Device to place the model on (cpu/cuda).
    dtype
        Data type for model parameters and computations (e.g. torch.float32, torch.float16).
    custom_config
        Extra model configuration that overrides or extends the predefined one.
    progress_bar
        Show a progress bar while downloading.
    classification_kwargs
        Extra keyword arguments passed to inference_preset (ignored for detection models).
    detection_kwargs
        Extra keyword arguments passed to InferenceTransform (ignored for classification models).
        When "dynamic_size" is absent it is filled from the model signature's "dynamic" entry.

    Returns
    -------
    A 3-tuple of:
    - the network loaded with the pre-trained weights,
    - the model info (class mappings, signature, RGB stats),
    - the inference transform appropriate for the model's task.
    """

    loaded_net, info = load_pretrained_model(
        weights,
        dst=dst,
        file_format=file_format,
        inference=inference,
        device=device,
        dtype=dtype,
        custom_config=custom_config,
        progress_bar=progress_bar,
    )

    input_size = lib.get_size_from_signature(info.signature)

    transform: Callable[..., torch.Tensor]
    if not isinstance(info, DetectionModelInfo):
        preset_kwargs: dict[str, Any] = {}
        if classification_kwargs is not None:
            preset_kwargs.update(classification_kwargs)

        transform = inference_preset(input_size, info.rgb_stats, **preset_kwargs)
    else:
        # The signature default is read unconditionally; caller-supplied values win
        signature_dynamic = info.signature["dynamic"]
        transform_kwargs: dict[str, Any] = {}
        if detection_kwargs is not None:
            transform_kwargs.update(detection_kwargs)

        transform_kwargs.setdefault("dynamic_size", signature_dynamic)
        transform = InferenceTransform(input_size, info.rgb_stats, **transform_kwargs)

    return (loaded_net, info, transform)
|
|
999
|
+
|
|
1000
|
+
|
|
921
1001
|
def load_model_with_cfg(
|
|
922
1002
|
cfg: dict[str, Any] | str | Path, weights_path: Optional[str | Path]
|
|
923
1003
|
) -> tuple[torch.nn.Module, dict[str, Any]]:
|
birder/common/training_cli.py
CHANGED
|
@@ -485,8 +485,13 @@ def add_dataloader_args(
|
|
|
485
485
|
)
|
|
486
486
|
|
|
487
487
|
|
|
488
|
-
def add_precision_args(parser: argparse.ArgumentParser) -> None:
|
|
488
|
+
def add_precision_args(parser: argparse.ArgumentParser, channels_last: bool = False) -> None:
|
|
489
489
|
group = parser.add_argument_group("Precision parameters")
|
|
490
|
+
if channels_last is True:
|
|
491
|
+
group.add_argument(
|
|
492
|
+
"--channels-last", default=False, action="store_true", help="use channels-last memory format"
|
|
493
|
+
)
|
|
494
|
+
|
|
490
495
|
group.add_argument(
|
|
491
496
|
"--model-dtype",
|
|
492
497
|
type=str,
|
birder/common/training_utils.py
CHANGED
|
@@ -1165,12 +1165,16 @@ def init_training(
|
|
|
1165
1165
|
device_id = torch.cuda.current_device()
|
|
1166
1166
|
|
|
1167
1167
|
if args.use_deterministic_algorithms is True:
|
|
1168
|
+
log.debug("Turning on deterministic algorithms")
|
|
1168
1169
|
torch.backends.cudnn.benchmark = False
|
|
1169
1170
|
torch.use_deterministic_algorithms(True)
|
|
1170
1171
|
elif cudnn_dynamic_size is True:
|
|
1171
1172
|
# Dynamic sizes: avoid per-size algorithm selection overhead.
|
|
1173
|
+
log.debug("Turning off cudnn")
|
|
1172
1174
|
torch.backends.cudnn.enabled = False
|
|
1175
|
+
torch.backends.cudnn.benchmark = False
|
|
1173
1176
|
else:
|
|
1177
|
+
log.debug("Turning on cudnn")
|
|
1174
1178
|
torch.backends.cudnn.enabled = True
|
|
1175
1179
|
torch.backends.cudnn.benchmark = True
|
|
1176
1180
|
|
|
@@ -15,7 +15,9 @@ def collate_fn(batch: list[tuple[Any, ...]]) -> tuple[Any, ...]:
|
|
|
15
15
|
return tuple(zip(*batch))
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
def batch_images(
|
|
18
|
+
def batch_images(
|
|
19
|
+
images: list[torch.Tensor], size_divisible: int
|
|
20
|
+
) -> tuple[torch.Tensor, torch.Tensor, list[tuple[int, int]]]:
|
|
19
21
|
"""
|
|
20
22
|
Batch list of image tensors of different sizes into a single batch.
|
|
21
23
|
Pad with zeros all images to the shape of the largest image in the list.
|
birder/datahub/_lib.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
import tarfile
|
|
3
|
+
import zipfile
|
|
3
4
|
from pathlib import Path
|
|
4
5
|
|
|
5
6
|
from birder.common import cli
|
|
@@ -26,9 +27,17 @@ def download_url(url: str, target: str | Path, sha256: str, progress_bar: bool =
|
|
|
26
27
|
|
|
27
28
|
def extract_archive(from_path: str | Path, to_path: str | Path) -> None:
    """
    Extract a zip or tar archive into a directory

    The archive type is selected by file extension: a ".zip" suffix (matched
    case-insensitively) is extracted with zipfile, anything else is opened with
    tarfile in "r" mode (transparent compression detection).

    Parameters
    ----------
    from_path
        Path of the archive to extract.
    to_path
        Directory to extract the archive contents into.
    """

    logger.info(f"Extracting {from_path} to {to_path}")
    from_path = Path(from_path)

    # Case-insensitive so archives such as "DATA.ZIP" are not handed to tarfile
    if from_path.suffix.lower() == ".zip":
        with zipfile.ZipFile(from_path, "r") as zf:
            # zipfile strips absolute paths and ".." components from member names on extraction
            zf.extractall(to_path)  # nosec

    else:
        with tarfile.open(from_path, "r") as tar:
            if hasattr(tarfile, "data_filter") is True:
                # The "data" filter refuses members that would escape to_path or are special files
                tar.extractall(to_path, filter="data")
            else:
                # NOTE: Remove once minimum Python version is 3.12 or above
                tar.extractall(to_path)  # nosec # tarfile_unsafe_members
|