birder 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. birder/__init__.py +2 -0
  2. birder/common/fs_ops.py +81 -1
  3. birder/common/training_cli.py +6 -1
  4. birder/common/training_utils.py +4 -0
  5. birder/data/collators/detection.py +3 -1
  6. birder/datahub/_lib.py +15 -6
  7. birder/datahub/evaluation.py +591 -0
  8. birder/eval/__init__.py +0 -0
  9. birder/eval/__main__.py +74 -0
  10. birder/eval/_embeddings.py +50 -0
  11. birder/eval/adversarial.py +315 -0
  12. birder/eval/benchmarks/__init__.py +0 -0
  13. birder/eval/benchmarks/awa2.py +357 -0
  14. birder/eval/benchmarks/bioscan5m.py +198 -0
  15. birder/eval/benchmarks/fishnet.py +318 -0
  16. birder/eval/benchmarks/flowers102.py +210 -0
  17. birder/eval/benchmarks/fungiclef.py +261 -0
  18. birder/eval/benchmarks/nabirds.py +202 -0
  19. birder/eval/benchmarks/newt.py +262 -0
  20. birder/eval/benchmarks/plankton.py +255 -0
  21. birder/eval/benchmarks/plantdoc.py +259 -0
  22. birder/eval/benchmarks/plantnet.py +252 -0
  23. birder/eval/classification.py +235 -0
  24. birder/eval/methods/__init__.py +0 -0
  25. birder/eval/methods/ami.py +78 -0
  26. birder/eval/methods/knn.py +71 -0
  27. birder/eval/methods/linear.py +152 -0
  28. birder/eval/methods/mlp.py +178 -0
  29. birder/eval/methods/simpleshot.py +100 -0
  30. birder/eval/methods/svm.py +92 -0
  31. birder/inference/classification.py +23 -2
  32. birder/inference/detection.py +35 -15
  33. birder/net/cswin_transformer.py +2 -1
  34. birder/net/detection/base.py +41 -18
  35. birder/net/detection/deformable_detr.py +63 -39
  36. birder/net/detection/detr.py +23 -20
  37. birder/net/detection/efficientdet.py +42 -25
  38. birder/net/detection/faster_rcnn.py +53 -21
  39. birder/net/detection/fcos.py +42 -23
  40. birder/net/detection/lw_detr.py +58 -35
  41. birder/net/detection/plain_detr.py +54 -43
  42. birder/net/detection/retinanet.py +46 -34
  43. birder/net/detection/rt_detr_v1.py +41 -38
  44. birder/net/detection/rt_detr_v2.py +50 -40
  45. birder/net/detection/ssd.py +47 -31
  46. birder/net/detection/yolo_v2.py +33 -18
  47. birder/net/detection/yolo_v3.py +35 -33
  48. birder/net/detection/yolo_v4.py +35 -20
  49. birder/net/detection/yolo_v4_tiny.py +1 -2
  50. birder/net/hiera.py +44 -67
  51. birder/net/maxvit.py +2 -2
  52. birder/net/mim/fcmae.py +2 -2
  53. birder/net/mim/mae_hiera.py +9 -16
  54. birder/net/nextvit.py +4 -4
  55. birder/net/rope_deit3.py +1 -1
  56. birder/net/rope_flexivit.py +1 -1
  57. birder/net/rope_vit.py +1 -1
  58. birder/net/squeezenet.py +1 -1
  59. birder/net/ssl/capi.py +32 -25
  60. birder/net/ssl/dino_v2.py +12 -15
  61. birder/net/ssl/franca.py +26 -19
  62. birder/net/van.py +2 -2
  63. birder/net/xcit.py +1 -1
  64. birder/ops/msda.py +46 -16
  65. birder/scripts/benchmark.py +35 -8
  66. birder/scripts/predict.py +14 -1
  67. birder/scripts/predict_detection.py +7 -1
  68. birder/scripts/train.py +15 -3
  69. birder/scripts/train_detection.py +16 -6
  70. birder/scripts/train_franca.py +10 -2
  71. birder/scripts/train_kd.py +16 -3
  72. birder/tools/adversarial.py +5 -0
  73. birder/tools/convert_model.py +101 -43
  74. birder/tools/quantize_model.py +33 -16
  75. birder/version.py +1 -1
  76. {birder-0.4.2.dist-info → birder-0.4.4.dist-info}/METADATA +16 -9
  77. {birder-0.4.2.dist-info → birder-0.4.4.dist-info}/RECORD +81 -58
  78. birder/scripts/evaluate.py +0 -176
  79. {birder-0.4.2.dist-info → birder-0.4.4.dist-info}/WHEEL +0 -0
  80. {birder-0.4.2.dist-info → birder-0.4.4.dist-info}/entry_points.txt +0 -0
  81. {birder-0.4.2.dist-info → birder-0.4.4.dist-info}/licenses/LICENSE +0 -0
  82. {birder-0.4.2.dist-info → birder-0.4.4.dist-info}/top_level.txt +0 -0
birder/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from birder.common.fs_ops import load_model_with_cfg
2
2
  from birder.common.fs_ops import load_pretrained_model
3
+ from birder.common.fs_ops import load_pretrained_model_and_transform
3
4
  from birder.common.lib import get_channels_from_signature
4
5
  from birder.common.lib import get_size_from_signature
5
6
  from birder.data.transforms.classification import inference_preset as classification_transform
@@ -17,5 +18,6 @@ __all__ = [
17
18
  "list_pretrained_models",
18
19
  "load_model_with_cfg",
19
20
  "load_pretrained_model",
21
+ "load_pretrained_model_and_transform",
20
22
  "__version__",
21
23
  ]
birder/common/fs_ops.py CHANGED
@@ -2,6 +2,7 @@ import json
2
2
  import logging
3
3
  import os
4
4
  import re
5
+ from collections.abc import Callable
5
6
  from collections.abc import Iterator
6
7
  from pathlib import Path
7
8
  from typing import Any
@@ -24,6 +25,8 @@ from birder.common.lib import get_network_name
24
25
  from birder.common.lib import get_pretrained_model_url
25
26
  from birder.conf import settings
26
27
  from birder.data.transforms.classification import RGBType
28
+ from birder.data.transforms.classification import inference_preset
29
+ from birder.data.transforms.detection import InferenceTransform
27
30
  from birder.model_registry import Task
28
31
  from birder.model_registry import registry
29
32
  from birder.model_registry.manifest import FileFormatType
@@ -801,7 +804,8 @@ def load_detection_model(
801
804
  for param in net.parameters():
802
805
  param.requires_grad_(False)
803
806
 
804
- net.eval()
807
+ if pt2 is False: # NOTE: Remove when GraphModule add support for 'eval'
808
+ net.eval()
805
809
 
806
810
  if len(backbone_loaded_config) == 0:
807
811
  backbone_custom_config = None
@@ -918,6 +922,82 @@ def load_pretrained_model(
918
922
  raise ValueError(f"Unknown model type: {model_metadata['task']}")
919
923
 
920
924
 
925
+ def load_pretrained_model_and_transform(
926
+ weights: str,
927
+ *,
928
+ dst: Optional[str | Path] = None,
929
+ file_format: FileFormatType = "pt",
930
+ inference: bool = True,
931
+ device: Optional[torch.device] = None,
932
+ dtype: Optional[torch.dtype] = None,
933
+ custom_config: Optional[dict[str, Any]] = None,
934
+ progress_bar: bool = True,
935
+ classification_kwargs: Optional[dict[str, Any]] = None,
936
+ detection_kwargs: Optional[dict[str, Any]] = None,
937
+ ) -> tuple[BaseNet | DetectionBaseNet, ModelInfo | DetectionModelInfo, Callable[..., torch.Tensor]]:
938
+ """
939
+ Loads a pre-trained model and builds the matching inference transform
940
+
941
+ This is a convenience helper for the common inference path where the model and
942
+ its default preprocessing are needed together. Classification models use
943
+ inference_preset, detection models use InferenceTransform.
944
+
945
+ Parameters
946
+ ----------
947
+ weights
948
+ Name of the pre-trained weights to load from the model registry.
949
+ dst
950
+ Destination path where the model weights will be downloaded or loaded from.
951
+ file_format
952
+ Model format (e.g. pt, pt2, safetensors, etc.)
953
+ inference
954
+ Flag to prepare the model for inference mode.
955
+ device
956
+ The device to load the model on (cpu/cuda).
957
+ dtype
958
+ Data type for model parameters and computations (e.g., torch.float32, torch.float16).
959
+ custom_config
960
+ Additional model configuration that overrides or extends the predefined configuration.
961
+ progress_bar
962
+ Whether to display a progress bar during file download.
963
+ classification_kwargs
964
+ Optional keyword arguments forwarded to inference_preset.
965
+ detection_kwargs
966
+ Optional keyword arguments forwarded to InferenceTransform. If dynamic_size is
967
+ not provided it defaults to the model signature value.
968
+
969
+ Returns
970
+ -------
971
+ A tuple containing three elements:
972
+ - A PyTorch module (neural network model) loaded with pre-trained weights.
973
+ - Model info containing class mappings, signature, and RGB stats.
974
+ - An inference transform matching the model task.
975
+ """
976
+
977
+ net, model_info = load_pretrained_model(
978
+ weights,
979
+ dst=dst,
980
+ file_format=file_format,
981
+ inference=inference,
982
+ device=device,
983
+ dtype=dtype,
984
+ custom_config=custom_config,
985
+ progress_bar=progress_bar,
986
+ )
987
+
988
+ size = lib.get_size_from_signature(model_info.signature)
989
+ transform: Callable[..., torch.Tensor]
990
+ if isinstance(model_info, DetectionModelInfo):
991
+ detection_args = {} if detection_kwargs is None else dict(detection_kwargs)
992
+ detection_args.setdefault("dynamic_size", model_info.signature["dynamic"])
993
+ transform = InferenceTransform(size, model_info.rgb_stats, **detection_args)
994
+ else:
995
+ classification_args = {} if classification_kwargs is None else dict(classification_kwargs)
996
+ transform = inference_preset(size, model_info.rgb_stats, **classification_args)
997
+
998
+ return (net, model_info, transform)
999
+
1000
+
921
1001
  def load_model_with_cfg(
922
1002
  cfg: dict[str, Any] | str | Path, weights_path: Optional[str | Path]
923
1003
  ) -> tuple[torch.nn.Module, dict[str, Any]]:
birder/common/training_cli.py CHANGED
@@ -485,8 +485,13 @@ def add_dataloader_args(
485
485
  )
486
486
 
487
487
 
488
- def add_precision_args(parser: argparse.ArgumentParser) -> None:
488
+ def add_precision_args(parser: argparse.ArgumentParser, channels_last: bool = False) -> None:
489
489
  group = parser.add_argument_group("Precision parameters")
490
+ if channels_last is True:
491
+ group.add_argument(
492
+ "--channels-last", default=False, action="store_true", help="use channels-last memory format"
493
+ )
494
+
490
495
  group.add_argument(
491
496
  "--model-dtype",
492
497
  type=str,
birder/common/training_utils.py CHANGED
@@ -1165,12 +1165,16 @@ def init_training(
1165
1165
  device_id = torch.cuda.current_device()
1166
1166
 
1167
1167
  if args.use_deterministic_algorithms is True:
1168
+ log.debug("Turning on deterministic algorithms")
1168
1169
  torch.backends.cudnn.benchmark = False
1169
1170
  torch.use_deterministic_algorithms(True)
1170
1171
  elif cudnn_dynamic_size is True:
1171
1172
  # Dynamic sizes: avoid per-size algorithm selection overhead.
1173
+ log.debug("Turning off cudnn")
1172
1174
  torch.backends.cudnn.enabled = False
1175
+ torch.backends.cudnn.benchmark = False
1173
1176
  else:
1177
+ log.debug("Turning on cudnn")
1174
1178
  torch.backends.cudnn.enabled = True
1175
1179
  torch.backends.cudnn.benchmark = True
1176
1180
 
birder/data/collators/detection.py CHANGED
@@ -15,7 +15,9 @@ def collate_fn(batch: list[tuple[Any, ...]]) -> tuple[Any, ...]:
15
15
  return tuple(zip(*batch))
16
16
 
17
17
 
18
- def batch_images(images: list[torch.Tensor], size_divisible: int) -> tuple[torch.Tensor, torch.Tensor, list[list[int]]]:
18
+ def batch_images(
19
+ images: list[torch.Tensor], size_divisible: int
20
+ ) -> tuple[torch.Tensor, torch.Tensor, list[tuple[int, int]]]:
19
21
  """
20
22
  Batch list of image tensors of different sizes into a single batch.
21
23
  Pad with zeros all images to the shape of the largest image in the list.
birder/datahub/_lib.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import logging
2
2
  import tarfile
3
+ import zipfile
3
4
  from pathlib import Path
4
5
 
5
6
  from birder.common import cli
@@ -26,9 +27,17 @@ def download_url(url: str, target: str | Path, sha256: str, progress_bar: bool =
26
27
 
27
28
  def extract_archive(from_path: str | Path, to_path: str | Path) -> None:
28
29
  logger.info(f"Extracting {from_path} to {to_path}")
29
- with tarfile.open(from_path, "r") as tar:
30
- if hasattr(tarfile, "data_filter") is True:
31
- tar.extractall(to_path, filter="data")
32
- else:
33
- # NOTE: Remove once minimum Python version is 3.12 or above
34
- tar.extractall(to_path) # nosec # tarfile_unsafe_members
30
+ if isinstance(from_path, str):
31
+ from_path = Path(from_path)
32
+
33
+ if from_path.suffix == ".zip":
34
+ with zipfile.ZipFile(from_path, "r") as zf:
35
+ zf.extractall(to_path) # nosec # tarfile_unsafe_members
36
+
37
+ else:
38
+ with tarfile.open(from_path, "r") as tar:
39
+ if hasattr(tarfile, "data_filter") is True:
40
+ tar.extractall(to_path, filter="data")
41
+ else:
42
+ # NOTE: Remove once minimum Python version is 3.12 or above
43
+ tar.extractall(to_path) # nosec # tarfile_unsafe_members