libviper 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,165 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110
+ .pdm.toml
111
+ .pdm-python
112
+ .pdm-build/
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ #.idea/
163
+
164
+ # env file
165
+ .env
libviper-0.1.0/LICENSE ADDED
@@ -0,0 +1,28 @@
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2025, Martin Kvisvik Larsen
4
+
5
+ Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation
13
+ and/or other materials provided with the distribution.
14
+
15
+ 3. Neither the name of the copyright holder nor the names of its
16
+ contributors may be used to endorse or promote products derived from
17
+ this software without specific prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@@ -0,0 +1,165 @@
1
+ Metadata-Version: 2.4
2
+ Name: libviper
3
+ Version: 0.1.0
4
+ Summary: A package for visual place recognition (VPR) models.
5
+ Author-email: Martin Kvisvik Larsen <martin.kvisvik.larsen@hotmail.com>
6
+ License-File: LICENSE
7
+ Requires-Python: >=3.12
8
+ Requires-Dist: einops>=0.8.1
9
+ Requires-Dist: fast-pytorch-kmeans>=0.2.2
10
+ Requires-Dist: imageio>=2.37.0
11
+ Requires-Dist: loguru>=0.7.3
12
+ Requires-Dist: msgspec>=0.19.0
13
+ Requires-Dist: pytest>=8.3.4
14
+ Requires-Dist: python-dotenv>=1.0.1
15
+ Requires-Dist: pytorch-lightning>=2.6.0
16
+ Requires-Dist: ruff>=0.8.3
17
+ Requires-Dist: scipy>=1.16.0
18
+ Requires-Dist: torch>=2.7.0
19
+ Requires-Dist: torchvision>=0.24.0
20
+ Requires-Dist: tqdm>=4.67.1
21
+ Requires-Dist: xformers>=0.0.30
22
+ Description-Content-Type: text/markdown
23
+
24
+ # Viper: A Common Image Embedder Interface for Visual Place Recognition
25
+
26
+ ![ci](https://github.com/markvilar/viper/actions/workflows/ubuntu.yml/badge.svg)
27
+
28
+ This Python package provides a **unified** image embedder interface for visual place recognition (VPR), along with wrapper implementations of several state-of-the-art VPR models so they all expose the same API.
29
+ It also includes a lightweight registry mechanism that lets you register custom embedders and retrieve them by string key.
30
+
31
+ ## Features
32
+
33
+ - Common `ImageEmbedder` protocol for VPR models (name, vector size, device, call semantics).
34
+ - Eight wrapper models adapting popular VPR architectures to this interface:
35
+ - AnyLoc
36
+ - CliqueMining
37
+ - CosPlace
38
+ - EigenPlaces
39
+ - MegaLoc
40
+ - MixVPR
41
+ - NetVLAD
42
+ - SALAD
43
+ - Simple registry (`register_embedder_factory` / `get_embedder_factory`) for instantiating embedders by key.
44
+
45
+ ## Installation
46
+
47
+ The package is configured as a standard Python project via `pyproject.toml` and adds support for the `uv` package manager.
48
+
49
+ You can install it in editable mode for development:
50
+
51
+ ```bash
52
+ uv sync
53
+ ```
54
+
55
+ **Disclaimer:** AnyLoc, CliqueMining, MegaLoc, and SALAD require CUDA to run.
56
+
57
+ ## Usage
58
+
59
+ ### Loading a built-in embedder
60
+
61
+ The recommended way to construct models is through the embedder registry.
62
+
63
+ ```python
64
+ import viper
65
+
66
+ factory: viper.ImageEmbedderFactory = viper.get_embedder_factory("salad") # or "mixvpr", "netvlad", "eigenplaces", ...
67
+ embedder: viper.ImageEmbedder = factory()
68
+
69
+ print(embedder.name) # "salad"
70
+ print(embedder.vector_size) # 8448
71
+ print(embedder.embedder_parameters) # dict of parameters
72
+ print(embedder.device) # "cpu" or "cuda"
73
+ ```
74
+
75
+ All embedders implement the `ImageEmbedder` protocol, which exposes:
76
+
77
+ - `name: str`
78
+ - `vector_size: int` (embedding dimension)
79
+ - `embedder_parameters: dict[str, Any]` (model-specific metadata such as backbone, descriptor size, etc.)
80
+ - `device: str` (e.g. `"cpu"` or `"cuda:0"`)
81
+ - `__call__(images: torch.Tensor) -> torch.Tensor` (batched embedding, `B x C x H x W -> B x E`)
82
+
83
+
84
+ ### Embedding a batch of images
85
+
86
+ All wrappers expect a batch of images as a float tensor of shape `B x C x H x W` with values in `[0.0, 1.0]` (for NetVLAD this is explicitly asserted).
87
+
88
+ ```python
89
+ import torch
90
+ import viper
91
+
92
+ factory: viper.ImageEmbedderFactory = viper.get_embedder_factory("netvlad")
93
+ embedder: viper.ImageEmbedder = factory()
94
+ images: torch.Tensor = torch.rand(8, 3, 480, 640) # example batch, normalized to [0, 1]
95
+ embeddings: torch.Tensor = embedder(images) # shape: (8, embedder.vector_size)
96
+ ```
97
+
98
+ Wrappers handle grayscale input by converting 1-channel batches to 3-channel RGB internally.
99
+ Some models also resize images so that height/width are multiples of 14, matching DINOv2 backbone constraints.
100
+
101
+ ### Registering a custom model
102
+
103
+ You can register your own model as long as it fulfills the `ImageEmbedder` interface.
104
+
105
+ ```python
106
+ import torch
107
+ from viper.registry import register_embedder_factory
108
+ from viper.types import ImageEmbedder
109
+
110
+ class MyEmbedder(torch.nn.Module):
111
+ @property
112
+ def name(self) -> str:
113
+ return "myembedder"
114
+
115
+ @property
116
+ def vector_size(self) -> int:
117
+ return 1024
118
+
119
+ @property
120
+ def embedder_parameters(self) -> dict:
121
+ return {"backbone": "resnet50"}
122
+
123
+ @property
124
+ def device(self) -> str:
125
+ return next(self.parameters()).device.type
126
+
127
+ def forward(self, images: torch.Tensor) -> torch.Tensor:
128
+ # return embeddings of shape B x 1024
129
+ ...
130
+
131
+ # Factory that returns an ImageEmbedder instance
132
+ @register_embedder_factory(key="myembedder")
133
+ def load_myembedder() -> ImageEmbedder:
134
+ model = MyEmbedder()
135
+ return model
136
+
137
+ ```
138
+
139
+ You can then retrieve a factory via `get_embedder_factory("myembedder")` like the built-in models.
140
+
141
+
142
+ ## Testing
143
+
144
+ The repository includes tests for the registry and embedder factories.
145
+
146
+ To run them:
147
+
148
+ ```bash
149
+ uv run pytest
150
+ ```
151
+
152
+
153
+ ## Acknowledgements
154
+
155
+ This package reuses ideas, code, and checkpoints from several excellent VPR projects.
156
+ Please cite and credit the original works when using the corresponding models.
157
+
158
+ - [AnyLoc](https://github.com/AnyLoc/DINO)
159
+ - [CliqueMining](https://github.com/serizba/clique-mining)
160
+ - [CosPlace](https://github.com/gmberton/CosPlace)
161
+ - [EigenPlaces](https://github.com/gmberton/EigenPlaces)
162
+ - [MegaLoc](https://github.com/gmberton/MegaLoc)
163
+ - [MixVPR](https://github.com/amaralibey/MixVPR)
164
+ - [NetVLAD](https://github.com/cvg/Hierarchical-Localization)
165
+ - [SALAD](https://github.com/serizba/salad)
@@ -0,0 +1,142 @@
1
+ # Viper: A Common Image Embedder Interface for Visual Place Recognition
2
+
3
+ ![ci](https://github.com/markvilar/viper/actions/workflows/ubuntu.yml/badge.svg)
4
+
5
+ This Python package provides a **unified** image embedder interface for visual place recognition (VPR), along with wrapper implementations of several state-of-the-art VPR models so they all expose the same API.
6
+ It also includes a lightweight registry mechanism that lets you register custom embedders and retrieve them by string key.
7
+
8
+ ## Features
9
+
10
+ - Common `ImageEmbedder` protocol for VPR models (name, vector size, device, call semantics).
11
+ - Eight wrapper models adapting popular VPR architectures to this interface:
12
+ - AnyLoc
13
+ - CliqueMining
14
+ - CosPlace
15
+ - EigenPlaces
16
+ - MegaLoc
17
+ - MixVPR
18
+ - NetVLAD
19
+ - SALAD
20
+ - Simple registry (`register_embedder_factory` / `get_embedder_factory`) for instantiating embedders by key.
21
+
22
+ ## Installation
23
+
24
+ The package is configured as a standard Python project via `pyproject.toml` and adds support for the `uv` package manager.
25
+
26
+ You can install it in editable mode for development:
27
+
28
+ ```bash
29
+ uv sync
30
+ ```
31
+
32
+ **Disclaimer:** AnyLoc, CliqueMining, MegaLoc, and SALAD require CUDA to run.
33
+
34
+ ## Usage
35
+
36
+ ### Loading a built-in embedder
37
+
38
+ The recommended way to construct models is through the embedder registry.
39
+
40
+ ```python
41
+ import viper
42
+
43
+ factory: viper.ImageEmbedderFactory = viper.get_embedder_factory("salad") # or "mixvpr", "netvlad", "eigenplaces", ...
44
+ embedder: viper.ImageEmbedder = factory()
45
+
46
+ print(embedder.name) # "salad"
47
+ print(embedder.vector_size) # 8448
48
+ print(embedder.embedder_parameters) # dict of parameters
49
+ print(embedder.device) # "cpu" or "cuda"
50
+ ```
51
+
52
+ All embedders implement the `ImageEmbedder` protocol, which exposes:
53
+
54
+ - `name: str`
55
+ - `vector_size: int` (embedding dimension)
56
+ - `embedder_parameters: dict[str, Any]` (model-specific metadata such as backbone, descriptor size, etc.)
57
+ - `device: str` (e.g. `"cpu"` or `"cuda:0"`)
58
+ - `__call__(images: torch.Tensor) -> torch.Tensor` (batched embedding, `B x C x H x W -> B x E`)
59
+
60
+
61
+ ### Embedding a batch of images
62
+
63
+ All wrappers expect a batch of images as a float tensor of shape `B x C x H x W` with values in `[0.0, 1.0]` (for NetVLAD this is explicitly asserted).
64
+
65
+ ```python
66
+ import torch
67
+ import viper
68
+
69
+ factory: viper.ImageEmbedderFactory = viper.get_embedder_factory("netvlad")
70
+ embedder: viper.ImageEmbedder = factory()
71
+ images: torch.Tensor = torch.rand(8, 3, 480, 640) # example batch, normalized to [0, 1]
72
+ embeddings: torch.Tensor = embedder(images) # shape: (8, embedder.vector_size)
73
+ ```
74
+
75
+ Wrappers handle grayscale input by converting 1-channel batches to 3-channel RGB internally.
76
+ Some models also resize images so that height/width are multiples of 14, matching DINOv2 backbone constraints.
77
+
78
+ ### Registering a custom model
79
+
80
+ You can register your own model as long as it fulfills the `ImageEmbedder` interface.
81
+
82
+ ```python
83
+ import torch
84
+ from viper.registry import register_embedder_factory
85
+ from viper.types import ImageEmbedder
86
+
87
+ class MyEmbedder(torch.nn.Module):
88
+ @property
89
+ def name(self) -> str:
90
+ return "myembedder"
91
+
92
+ @property
93
+ def vector_size(self) -> int:
94
+ return 1024
95
+
96
+ @property
97
+ def embedder_parameters(self) -> dict:
98
+ return {"backbone": "resnet50"}
99
+
100
+ @property
101
+ def device(self) -> str:
102
+ return next(self.parameters()).device.type
103
+
104
+ def forward(self, images: torch.Tensor) -> torch.Tensor:
105
+ # return embeddings of shape B x 1024
106
+ ...
107
+
108
+ # Factory that returns an ImageEmbedder instance
109
+ @register_embedder_factory(key="myembedder")
110
+ def load_myembedder() -> ImageEmbedder:
111
+ model = MyEmbedder()
112
+ return model
113
+
114
+ ```
115
+
116
+ You can then retrieve a factory via `get_embedder_factory("myembedder")` like the built-in models.
117
+
118
+
119
+ ## Testing
120
+
121
+ The repository includes tests for the registry and embedder factories.
122
+
123
+ To run them:
124
+
125
+ ```bash
126
+ uv run pytest
127
+ ```
128
+
129
+
130
+ ## Acknowledgements
131
+
132
+ This package reuses ideas, code, and checkpoints from several excellent VPR projects.
133
+ Please cite and credit the original works when using the corresponding models.
134
+
135
+ - [AnyLoc](https://github.com/AnyLoc/DINO)
136
+ - [CliqueMining](https://github.com/serizba/clique-mining)
137
+ - [CosPlace](https://github.com/gmberton/CosPlace)
138
+ - [EigenPlaces](https://github.com/gmberton/EigenPlaces)
139
+ - [MegaLoc](https://github.com/gmberton/MegaLoc)
140
+ - [MixVPR](https://github.com/amaralibey/MixVPR)
141
+ - [NetVLAD](https://github.com/cvg/Hierarchical-Localization)
142
+ - [SALAD](https://github.com/serizba/salad)
@@ -0,0 +1,50 @@
1
+ [project]
2
+ name = "libviper"
3
+ version = "0.1.0"
4
+ description = "A package for visual place recognition (VPR) models."
5
+ readme = "README.md"
6
+ authors = [
7
+ { name = "Martin Kvisvik Larsen", email = "martin.kvisvik.larsen@hotmail.com" }
8
+ ]
9
+ requires-python = ">=3.12"
10
+ dependencies = [
11
+ "imageio>=2.37.0",
12
+ "loguru>=0.7.3",
13
+ "msgspec>=0.19.0",
14
+ "pytest>=8.3.4",
15
+ "python-dotenv>=1.0.1",
16
+ "ruff>=0.8.3",
17
+ "torch>=2.7.0",
18
+ "torchvision>=0.24.0",
19
+ "tqdm>=4.67.1",
20
+ "scipy>=1.16.0",
21
+ "einops>=0.8.1", # NOTE: required by anyloc
22
+ "xformers>=0.0.30", # NOTE: required by megaloc
23
+ "pytorch-lightning>=2.6.0", # NOTE: required by cliquemining
24
+ "fast-pytorch-kmeans>=0.2.2", # NOTE: required by anyloc
25
+ ]
26
+
27
+ [project.scripts]
28
+
29
+ [build-system]
30
+ requires = ["hatchling"]
31
+ build-backend = "hatchling.build"
32
+
33
+ [tool.hatch.build]
34
+ include = ["src/viper"]
35
+
36
+ [tool.hatch.build.targets.wheel]
37
+ packages = ["src/viper"]
38
+
39
+ [tool.uv]
40
+ package = true
41
+
42
+ [tool.uv.sources]
43
+
44
+ [tool.ruff]
45
+ line-length = 88 # default line width of ruff format
46
+
47
+ [dependency-groups]
48
+ dev = [
49
+ "ruff>=0.8.3",
50
+ ]
@@ -0,0 +1,15 @@
1
+ """
2
+ Package for visual place recognition (VPR) models.
3
+ """
4
+
5
+ import viper.models as models # noqa: F401
6
+
7
+ from .registry import FactoryRegistry as FactoryRegistry
8
+ from .registry import register_embedder_factory as register_embedder_factory
9
+ from .registry import get_embedder_factory_registry as get_embedder_factory_registry
10
+ from .registry import get_embedder_factory as get_embedder_factory
11
+
12
+ from .types import ImageEmbedder as ImageEmbedder
13
+ from .types import ImageEmbedderFactory as ImageEmbedderFactory
14
+
15
+ __all__ = []
@@ -0,0 +1,15 @@
1
+ """
2
+ Package with viper's built-in VPR models.
3
+ """
4
+
5
+ import viper.models.anyloc as anyloc # noqa: F401
6
+ import viper.models.cliquemining as cliquemining # noqa: F401
7
+ import viper.models.cosplace as cosplace # noqa: F401
8
+ import viper.models.eigenplaces as eigenplaces # noqa: F401
9
+ import viper.models.megaloc as megaloc # noqa: F401
10
+ import viper.models.mixvpr as mixvpr # noqa: F401
11
+ import viper.models.netvlad as netvlad # noqa: F401
12
+ import viper.models.salad as salad # noqa: F401
13
+
14
+
15
+ __all__ = []
@@ -0,0 +1,92 @@
1
+ """Module for the Anyloc VPR model."""
2
+
3
+ import typing
4
+ import torch
5
+
6
+ from viper.registry import register_embedder_factory
7
+ from viper.types import ImageEmbedder
8
+ from .helpers import convert_grayscale_batch_to_rgb
9
+ from .helpers import calculate_image_size_dinov2
10
+ from .helpers import resize_image_batch
11
+
12
+
13
class AnyLocWrapper(torch.nn.Module):
    """
    Class representing a wrapper for the AnyLoc model.

    The wrapped implementation is expected to expose a `vlad` attribute with
    `num_clusters` and `desc_dim`, which determine the embedding size.
    """

    def __init__(self, impl: torch.nn.Module) -> None:
        """
        Initializer method.

        Arguments:
            impl: the underlying AnyLoc model that computes the embeddings
        """
        super().__init__()
        self._impl = impl
        # NOTE: Add dummy parameter to infer device
        self._param = torch.nn.Parameter(torch.tensor(1.0))

    @property
    def name(self) -> str:
        """Returns the name of the embedder."""
        return "anyloc"

    @property
    def vector_size(self) -> int:
        """Returns the size, i.e. dimensions, of the image embeddings."""
        return self._impl.vlad.num_clusters * self._impl.vlad.desc_dim

    @property
    def embedder_parameters(self) -> dict[str, typing.Any]:
        """Returns the parameters of the embedder."""
        return {
            "num_clusters": self._impl.vlad.num_clusters,
            "descriptor_dimensions": self._impl.vlad.desc_dim,
        }

    @property
    def device(self) -> str:
        """Returns the device of the embedder as a string, e.g. "cpu" or "cuda:0"."""
        # NOTE: The parameter carries a torch.device object, while this property
        # is declared to return a string - convert explicitly.
        return str(next(self.parameters()).device)

    def __call__(self, images: torch.Tensor) -> torch.Tensor:
        """
        Embeds a batch of images.

        Arguments:
            images: batch of images - shape BxCxHxW
        Returns:
            batch of image embeddings - shape BxE
        """
        # NOTE: Go through torch.nn.Module.__call__ instead of calling forward
        # directly, so that hooks registered on the module are still executed.
        return super().__call__(images)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Forwards a batch of images through the module.

        Arguments:
            images: batch of images - shape BxCxHxW
        Returns:
            batch of image embeddings - shape BxE
        Raises:
            AssertionError: if the batch is not 4-dimensional, or does not have
                1 or 3 channels
        """
        # Validate the rank first, so that indexing the channel dimension below
        # cannot fail with an unrelated IndexError
        assert images.dim() == 4, f"invalid batch dimensions: {images.dim()}"

        # If the image batch is grayscale, convert to 3 channels
        if images.shape[1] == 1:
            images = convert_grayscale_batch_to_rgb(images)

        assert images.shape[1] == 3, f"invalid image batch channels: {images.shape[1]}"

        # Resize the batch to the image size calculated for the DINOv2 backbone
        desired_image_size: tuple[int, int] = calculate_image_size_dinov2(images)
        images_resized: torch.Tensor = resize_image_batch(images, desired_image_size)

        return self._impl(images_resized)
75
+
76
+
77
@register_embedder_factory(key="anyloc")
def load_anyloc() -> ImageEmbedder:
    """
    Loads an AnyLoc model.

    Returns:
        the AnyLoc model wrapped as an image embedder, in evaluation mode
    Raises:
        RuntimeError: if CUDA is not available
    """
    # NOTE: AnyLoc requires CUDA to run, hence we fail early if it is missing
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for AnyLoc but is not available.")

    impl: torch.nn.Module = torch.hub.load(
        "AnyLoc/DINO",
        "get_vlad_model",
        backbone="DINOv2",
        domain="unstructured",
        device="cuda",
    )
    # NOTE: The implementation is loaded on CUDA, while the wrapper creates its
    # device-probe parameter on the default device. Move the wrapper to CUDA as
    # well, so that the reported device matches where the model actually runs.
    wrapper: AnyLocWrapper = AnyLocWrapper(impl=impl).to("cuda").eval()
    return wrapper