torch-2dtm 0.5.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,105 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ env/
12
+ build/
13
+ develop-eggs/
14
+ dist/
15
+ downloads/
16
+ eggs/
17
+ .eggs/
18
+ lib/
19
+ lib64/
20
+ parts/
21
+ sdist/
22
+ var/
23
+ wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+
28
+ .DS_Store
29
+
30
+ # PyInstaller
31
+ *.manifest
32
+ *.spec
33
+
34
+ # Installer logs
35
+ pip-log.txt
36
+ pip-delete-this-directory.txt
37
+
38
+ # Unit test / coverage reports
39
+ htmlcov/
40
+ .tox/
41
+ .coverage
42
+ .coverage.*
43
+ .cache
44
+ nosetests.xml
45
+ coverage.xml
46
+ *.cover
47
+ .hypothesis/
48
+ .pytest_cache/
49
+
50
+ # Files downloaded for unit tests
51
+ **/tests/tmp/
52
+
53
+ # Translations
54
+ *.mo
55
+ *.pot
56
+
57
+ # Django stuff:
58
+ *.log
59
+ local_settings.py
60
+
61
+ # Flask stuff:
62
+ instance/
63
+ .webassets-cache
64
+
65
+ # Scrapy stuff:
66
+ .scrapy
67
+
68
+ # Sphinx documentation
69
+ docs/_build/
70
+
71
+ # PyBuilder
72
+ target/
73
+
74
+ # Jupyter Notebook
75
+ .ipynb_checkpoints
76
+
77
+ # dotenv
78
+ .env
79
+
80
+ # virtualenv
81
+ .venv
82
+ venv/
83
+ ENV/
84
+
85
+ # Spyder project settings
86
+ .spyderproject
87
+ .spyproject
88
+
89
+ # Rope project settings
90
+ .ropeproject
91
+
92
+ # mkdocs documentation
93
+ /site
94
+
95
+ # mypy
96
+ .mypy_cache/
97
+
98
+ # ruff
99
+ .ruff_cache/
100
+
101
+ # IDEs
102
+ .idea/
103
+ .vscode/
104
+
105
+ lightning_logs/
@@ -0,0 +1,29 @@
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2020, TeamTomo
4
+ All rights reserved.
5
+
6
+ Redistribution and use in source and binary forms, with or without
7
+ modification, are permitted provided that the following conditions are met:
8
+
9
+ 1. Redistributions of source code must retain the above copyright notice, this
10
+ list of conditions and the following disclaimer.
11
+
12
+ 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ this list of conditions and the following disclaimer in the documentation
14
+ and/or other materials provided with the distribution.
15
+
16
+ 3. Neither the name of the copyright holder nor the names of its
17
+ contributors may be used to endorse or promote products derived from
18
+ this software without specific prior written permission.
19
+
20
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@@ -0,0 +1,96 @@
1
+ Metadata-Version: 2.4
2
+ Name: torch-2dtm
3
+ Version: 0.5.0
4
+ Summary: 2D template matching in pytorch
5
+ Project-URL: homepage, https://github.com/teamtomo/teamtomo
6
+ Project-URL: repository, https://github.com/teamtomo/teamtomo
7
+ Author-email: Josh Dickerson <jdickerson@berkeley.edu>, Matthew Giammar <matthew_giammar@berkeley.edu>, Alister Burt <alisterburt@gmail.com>
8
+ License: BSD-3-Clause
9
+ License-File: LICENSE
10
+ Classifier: Development Status :: 3 - Alpha
11
+ Classifier: License :: OSI Approved :: BSD License
12
+ Classifier: Programming Language :: Python :: 3
13
+ Classifier: Programming Language :: Python :: 3.11
14
+ Classifier: Programming Language :: Python :: 3.12
15
+ Classifier: Programming Language :: Python :: 3.13
16
+ Classifier: Programming Language :: Python :: 3.14
17
+ Classifier: Typing :: Typed
18
+ Requires-Python: >=3.11
19
+ Requires-Dist: einops
20
+ Requires-Dist: setuptools
21
+ Requires-Dist: torch
22
+ Requires-Dist: torch-fourier-slice
23
+ Description-Content-Type: text/markdown
24
+
25
+ # torch-2dtm
26
+
27
+ [![License](https://img.shields.io/pypi/l/torch-2dtm.svg?color=green)](https://github.com/teamtomo/torch-2dtm/raw/main/LICENSE)
28
+ [![PyPI](https://img.shields.io/pypi/v/torch-2dtm.svg?color=green)](https://pypi.org/project/torch-2dtm)
29
+ [![Python Version](https://img.shields.io/pypi/pyversions/torch-2dtm.svg?color=green)](https://python.org)
30
+ [![CI](https://github.com/teamtomo/torch-2dtm/actions/workflows/ci.yml/badge.svg)](https://github.com/teamtomo/torch-2dtm/actions/workflows/ci.yml)
31
+ [![codecov](https://codecov.io/gh/teamtomo/torch-2dtm/branch/main/graph/badge.svg)](https://codecov.io/gh/teamtomo/torch-2dtm)
32
+
33
+ ## Overview
34
+
35
+ torch-2dtm is a Python package for efficient template matching of
36
+ 2D projections of a 3D template with a 2D image in PyTorch.
37
+
38
+ This is implemented for cryo-EM applications, see
39
+ [Rickgauer et al. 2017 eLife](https://doi.org/10.7554/eLife.25648) for details.
40
+
41
+ ## Features
42
+
43
+ - Fast 2D template matching using Fourier transforms
44
+ - Batch processing over orientations
45
+ - Batch processing over Fourier space filters (e.g. for defocus sweeps)
46
+ - GPU acceleration through PyTorch
47
+
48
+ Projections are calculated on-the-fly using
49
+ [*torch-fourier-slice*](https://github.com/teamtomo/torch-fourier-slice).
50
+
51
+ ## Installation
52
+
53
+ ```bash
54
+ pip install torch-2dtm
55
+ ```
56
+
57
+ ## Basic Usage
58
+
59
+ ```python
60
+ import torch
61
+ import torch_2dtm
62
+ from scipy.stats import special_ortho_group
63
+
64
+ # Create random test data
65
+ # 1. Create a random image and compute its FFT
66
+ image_size = (128, 128)
67
+ image = torch.randn(*image_size, dtype=torch.float32)
68
+ image_dft = torch.fft.rfftn(image, dim=(0, 1)) # Shape: (128, 65)
69
+
70
+ # 2. Create a random 3D template and compute its FFT
71
+ template_size = (64, 64, 64)
72
+ template = torch.randn(*template_size, dtype=torch.float32)
73
+ template_dft = torch.fft.fftshift(torch.fft.rfftn(template, dim=(0, 1, 2)), dim=(0, 1)) # Shape: (64, 64, 33), fftshifted along dims (0, 1)
74
+
75
+ # 3. Create a batch of random rotation matrices with shape (b, 3, 3)
76
+ num_orientations = 10
77
+ rotation_matrices = torch.tensor(special_ortho_group.rvs(size=num_orientations, dim=3), dtype=torch.float32)
78
+
79
+ # 4. Create an arbitrary stack of Fourier space filters (identity filter in this example)
80
+ # These filters operate on rffts of the 2D projection images
81
+ # Filter shape: (..., h, w // 2 + 1)
82
+ filters = torch.ones(template_size[0], template_size[1] // 2 + 1, dtype=torch.complex64)
83
+
84
+ # Perform template matching
85
+ cross_correlation = torch_2dtm.match_template_dft_2d(
86
+ image_dft=image_dft,
87
+ template_dft=template_dft,
88
+ rotation_matrices=rotation_matrices,
89
+ filters=filters
90
+ )
91
+ # The result has shape (..., num_orientations, image_height, image_width)
92
+ ```
93
+
94
+ ## License
95
+
96
+ This project is licensed under the BSD 3-Clause License - see the LICENSE file for details.
@@ -0,0 +1,72 @@
1
+ # torch-2dtm
2
+
3
+ [![License](https://img.shields.io/pypi/l/torch-2dtm.svg?color=green)](https://github.com/teamtomo/torch-2dtm/raw/main/LICENSE)
4
+ [![PyPI](https://img.shields.io/pypi/v/torch-2dtm.svg?color=green)](https://pypi.org/project/torch-2dtm)
5
+ [![Python Version](https://img.shields.io/pypi/pyversions/torch-2dtm.svg?color=green)](https://python.org)
6
+ [![CI](https://github.com/teamtomo/torch-2dtm/actions/workflows/ci.yml/badge.svg)](https://github.com/teamtomo/torch-2dtm/actions/workflows/ci.yml)
7
+ [![codecov](https://codecov.io/gh/teamtomo/torch-2dtm/branch/main/graph/badge.svg)](https://codecov.io/gh/teamtomo/torch-2dtm)
8
+
9
+ ## Overview
10
+
11
+ torch-2dtm is a Python package for efficient template matching of
12
+ 2D projections of a 3D template with a 2D image in PyTorch.
13
+
14
+ This is implemented for cryo-EM applications, see
15
+ [Rickgauer et al. 2017 eLife](https://doi.org/10.7554/eLife.25648) for details.
16
+
17
+ ## Features
18
+
19
+ - Fast 2D template matching using Fourier transforms
20
+ - Batch processing over orientations
21
+ - Batch processing over Fourier space filters (e.g. for defocus sweeps)
22
+ - GPU acceleration through PyTorch
23
+
24
+ Projections are calculated on-the-fly using
25
+ [*torch-fourier-slice*](https://github.com/teamtomo/torch-fourier-slice).
26
+
27
+ ## Installation
28
+
29
+ ```bash
30
+ pip install torch-2dtm
31
+ ```
32
+
33
+ ## Basic Usage
34
+
35
+ ```python
36
+ import torch
37
+ import torch_2dtm
38
+ from scipy.stats import special_ortho_group
39
+
40
+ # Create random test data
41
+ # 1. Create a random image and compute its FFT
42
+ image_size = (128, 128)
43
+ image = torch.randn(*image_size, dtype=torch.float32)
44
+ image_dft = torch.fft.rfftn(image, dim=(0, 1)) # Shape: (128, 65)
45
+
46
+ # 2. Create a random 3D template and compute its FFT
47
+ template_size = (64, 64, 64)
48
+ template = torch.randn(*template_size, dtype=torch.float32)
49
+ template_dft = torch.fft.fftshift(torch.fft.rfftn(template, dim=(0, 1, 2)), dim=(0, 1)) # Shape: (64, 64, 33), fftshifted along dims (0, 1)
50
+
51
+ # 3. Create a batch of random rotation matrices with shape (b, 3, 3)
52
+ num_orientations = 10
53
+ rotation_matrices = torch.tensor(special_ortho_group.rvs(size=num_orientations, dim=3), dtype=torch.float32)
54
+
55
+ # 4. Create an arbitrary stack of Fourier space filters (identity filter in this example)
56
+ # These filters operate on rffts of the 2D projection images
57
+ # Filter shape: (..., h, w // 2 + 1)
58
+ filters = torch.ones(template_size[0], template_size[1] // 2 + 1, dtype=torch.complex64)
59
+
60
+ # Perform template matching
61
+ cross_correlation = torch_2dtm.match_template_dft_2d(
62
+ image_dft=image_dft,
63
+ template_dft=template_dft,
64
+ rotation_matrices=rotation_matrices,
65
+ filters=filters
66
+ )
67
+ # The result has shape (..., num_orientations, image_height, image_width)
68
+ ```
69
+
70
+ ## License
71
+
72
+ This project is licensed under the BSD 3-Clause License - see the LICENSE file for details.
@@ -0,0 +1,112 @@
1
+ # https://peps.python.org/pep-0517/
2
+ [build-system]
3
+ requires = ["hatchling", "hatch-vcs"]
4
+ build-backend = "hatchling.build"
5
+
6
+ # https://hatch.pypa.io/latest/config/metadata/
7
+ [tool.hatch.version]
8
+ source = "vcs"
9
+ tag-pattern = "^torch-2dtm@v(?P<version>.+)$"
10
+ fallback-version = "0.5.0"
11
+
12
+ [tool.hatch.version.raw-options]
13
+ search_parent_directories = true
14
+ # Parse tags of the form: <package-name>@v<semver>
15
+ tag_regex = "^torch-2dtm@v(?P<version>\\d+\\.\\d+\\.\\d+.*)$"
16
+ # Constrain git-describe so it only considers TeamTomo's own tags, not other workspace tags.
17
+ # See https://github.com/ofek/hatch-vcs/issues/71
18
+ git_describe_command = "git describe --dirty --tags --long --match 'torch-2dtm@v[0-9]*.[0-9]*.[0-9]*'"
19
+
20
+ # read more about configuring hatch at:
21
+ # https://hatch.pypa.io/latest/config/build/
22
+ [tool.hatch.build.targets.wheel]
23
+ only-include = ["src"]
24
+ sources = ["src"]
25
+
26
+ # https://peps.python.org/pep-0621/
27
+ [project]
28
+ name = "torch-2dtm"
29
+ dynamic = ["version"]
30
+ description = "2D template matching in pytorch"
31
+ readme = "README.md"
32
+ requires-python = ">=3.11"
33
+ license = { text = "BSD-3-Clause" }
34
+ authors = [
35
+ { name = "Josh Dickerson", email = "jdickerson@berkeley.edu" },
36
+ { name = "Matthew Giammar", email = "matthew_giammar@berkeley.edu" },
37
+ { name = "Alister Burt", email = "alisterburt@gmail.com" },
38
+ ]
39
+ # https://pypi.org/classifiers/
40
+ classifiers = [
41
+ "Development Status :: 3 - Alpha",
42
+ "License :: OSI Approved :: BSD License",
43
+ "Programming Language :: Python :: 3",
44
+ "Programming Language :: Python :: 3.11",
45
+ "Programming Language :: Python :: 3.12",
46
+ "Programming Language :: Python :: 3.13",
47
+ "Programming Language :: Python :: 3.14",
48
+ "Typing :: Typed",
49
+ ]
50
+ # add your package dependencies here
51
+ dependencies = [
52
+ "torch",
53
+ "einops",
54
+ "torch-fourier-slice",
55
+ "setuptools", # NOTE(review): reportedly required at runtime by torch.compile; confirm and document why
56
+ ]
57
+
58
+
59
+ # https://peps.python.org/pep-0735/
60
+ # dependency groups (not extras): e.g. `pip install --group test` (pip >= 25.1) or `uv sync --group test`
61
+ [dependency-groups]
62
+ # add dependencies used for testing here
63
+ test = ["pytest", "pytest-cov", "scipy"]
64
+ # add anything else you like to have in your dev environment here
65
+ dev = [
66
+ { include-group = "test" },
67
+ "ipython",
68
+ "pdbpp", # https://github.com/pdbpp/pdbpp
69
+ "rich", # https://github.com/Textualize/rich
70
+ ]
71
+
72
+ [project.urls]
73
+ homepage = "https://github.com/teamtomo/teamtomo"
74
+ repository = "https://github.com/teamtomo/teamtomo"
75
+
76
+ # Entry points
77
+ # https://peps.python.org/pep-0621/#entry-points
78
+ # same as console_scripts entry point
79
+ # [project.scripts]
80
+ # torch-2dtm-cli = "torch_2dtm:main_cli"
81
+
82
+ # [project.entry-points."some.group"]
83
+ # tomatoes = "torch_2dtm:main_tomatoes"
84
+
85
+
86
+ # https://docs.pytest.org/
87
+ [tool.pytest.ini_options]
88
+ minversion = "7.0"
89
+ testpaths = ["tests"]
90
+ filterwarnings = ["error"]
91
+
92
+ # https://coverage.readthedocs.io/
93
+ [tool.coverage.report]
94
+ show_missing = true
95
+ exclude_lines = [
96
+ "pragma: no cover",
97
+ "if TYPE_CHECKING:",
98
+ "@overload",
99
+ "except ImportError",
100
+ "\\.\\.\\.",
101
+ "raise NotImplementedError()",
102
+ "pass",
103
+ ]
104
+
105
+ [tool.coverage.run]
106
+ source = ["torch_2dtm"]
107
+
108
+ # https://github.com/mgedmin/check-manifest#configuration
109
+ # add files that you want check-manifest to explicitly ignore here
110
+ # (files that are in the repo but shouldn't go in the package)
111
+ [tool.check-manifest]
112
+ ignore = [".pre-commit-config.yaml", ".ruff_cache/**/*", "tests/**/*"]
@@ -0,0 +1,9 @@
1
+ """2D template matching in pytorch"""
2
+
3
from importlib.metadata import PackageNotFoundError, version as _dist_version

try:
    # Use the version recorded in the installed distribution's metadata so it
    # always matches the version computed by hatch-vcs at build time, instead of
    # a hard-coded literal that drifts out of date.
    __version__ = _dist_version("torch-2dtm")
except PackageNotFoundError:
    # Not installed (e.g. imported straight from a source checkout).
    # Keep in sync with `fallback-version` in pyproject.toml.
    __version__ = "0.5.0"
__author__ = "Josh Dickerson"
__email__ = "jdickerson@berkeley.edu"
6
+
7
+ from .cross_correlate import match_template_dft_2d
8
+
9
+ __all__ = ["match_template_dft_2d"]
@@ -0,0 +1,90 @@
1
+ """Cross-correlation functions."""
2
+
3
+ import torch
4
+ import platform
5
+ import einops
6
+ from einops._torch_specific import allow_ops_in_compiled_graph
7
+ from torch_fourier_slice import extract_central_slices_rfft_3d
8
+
9
+ from torch_2dtm.utils import normalize_template_projection
10
+
11
# Register einops operations with torch.compile so the normalization helper
# (which the original implementation writes with einops) can be compiled.
allow_ops_in_compiled_graph()

# Backend choice: "aot_eager" on Linux (noted as more stable there than
# inductor), "inductor" on every other platform.
COMPILE_BACKEND = "aot_eager" if platform.system() == "Linux" else "inductor"

# Compiled wrapper around the normalization utility, used by the matcher below.
normalize_template_projection_compiled = torch.compile(
    normalize_template_projection,
    backend=COMPILE_BACKEND,
)
21
+
22
+
23
def match_template_dft_2d(
    image_dft: torch.Tensor,
    template_dft: torch.Tensor,
    rotation_matrices: torch.Tensor,
    filters: torch.Tensor,
) -> torch.Tensor:
    """Batched projection and cross-correlation with a set of filters.

    For every rotation matrix a central slice is extracted from the template's
    Fourier transform, multiplied by each filter, transformed back to real space,
    normalized, zero-padded to the image size and cross-correlated with the image.
    The returned cross-correlation images have the same size as the real-space
    image that ``image_dft`` was computed from.

    Parameters
    ----------
    image_dft : torch.Tensor
        `(h_im, w_im // 2 + 1)` Fourier transform (rfft) of the real-space image.
        Any filters etc. are assumed to have already been applied to this image.
        The real-space width is assumed to be even (``w_im = 2 * (w_rfft - 1)``).
    template_dft : torch.Tensor
        `(d, h, w // 2 + 1)` fftshifted Fourier transform (rfft) of the real-valued
        template volume to take Fourier slices from.
    rotation_matrices : torch.Tensor
        `(b, 3, 3)` batched rotation matrices used to rotate the slices sampled
        from the template Fourier transform.
    filters : torch.Tensor
        `(..., h, w // 2 + 1)` filters applied to the Fourier slices (rffts of the
        2D projections). Leading dimensions act as extra batch dimensions
        (e.g. a stack of defocus values).

    Returns
    -------
    torch.Tensor
        Cross-correlation of the image with the template projection for every
        filter and orientation, with shape `(..., b, h_im, w_im)` where `...`
        are the leading dimensions of ``filters``.

    Raises
    ------
    ValueError
        If the template projections are larger than the image, so they cannot be
        zero-padded to the image size.
    """
    # Grab relevant dimensions; real-space widths are recovered from the rfft
    # widths assuming even real-space sizes.
    _, h, w_rfft = template_dft.shape
    h_im, w_im_rfft = image_dft.shape
    w = 2 * (w_rfft - 1)
    w_im = 2 * (w_im_rfft - 1)

    # Zero-padding to the image size only makes sense if the projection fits;
    # otherwise rfftn(s=...) would silently crop it and the padded-variance
    # normalization would be meaningless.
    if h > h_im or w > w_im:
        raise ValueError(
            f"Template projections of shape {(h, w)} are larger than the image "
            f"of shape {(h_im, w_im)} and cannot be zero-padded to it."
        )

    # Extract central slice(s) from the template volume
    fourier_slices = extract_central_slices_rfft_3d(
        volume_rfft=template_dft,
        rotation_matrices=rotation_matrices,
    )  # (b, h, w // 2 + 1)
    # Undo the fftshift along the height dimension so index [0, 0] is DC
    fourier_slices = torch.fft.ifftshift(fourier_slices, dim=(-2,))
    fourier_slices[..., 0, 0] = 0 + 0j  # zero out the DC component (mean zero)
    fourier_slices *= -1  # flip contrast

    # Apply the projective filters with broadcasting:
    # (..., 1, h, w // 2 + 1) * (b, h, w // 2 + 1) -> (..., b, h, w // 2 + 1)
    filters = einops.rearrange(filters, "... h w -> ... 1 h w")
    fourier_slices = fourier_slices * filters

    # Inverse Fourier transform into real space and normalize
    projections = torch.fft.irfftn(fourier_slices, s=(h, w), dim=(-2, -1))
    projections = torch.fft.ifftshift(projections, dim=(-2, -1))
    projections = normalize_template_projection_compiled(
        projections, (h, w), (h_im, w_im)
    )

    # Zero-padded (to the image size) forward Fourier transform
    projections_dft = torch.fft.rfftn(projections, s=(h_im, w_im), dim=(-2, -1))

    # Zero the DC component (set mean zero)
    projections_dft[..., 0, 0] = 0 + 0j

    # Cross correlation: image spectrum times conjugate template spectrum
    projections_dft = image_dft * torch.conj(projections_dft)
    cross_correlation = torch.fft.irfftn(
        projections_dft, s=(h_im, w_im), dim=(-2, -1)
    )

    return cross_correlation  # (..., b, h_im, w_im)
@@ -0,0 +1,5 @@
1
+ You may remove this file if you don't intend to add types to your package
2
+
3
+ Details at:
4
+
5
+ https://mypy.readthedocs.io/en/stable/installed_packages.html#creating-pep-561-compatible-packages
@@ -0,0 +1,88 @@
1
+ """Utility functions associated with backend functions."""
2
+
3
+ import torch
4
+ import einops
5
+
6
+
7
def normalize_template_projection(
    projections: torch.Tensor,
    small_shape: tuple[int, int],
    large_shape: tuple[int, int],
) -> torch.Tensor:
    r"""Subtract mean of edge values and set variance to 1 (in large shape).

    The projections are never actually padded. Instead, the mean and variance of
    the zero-padded projection are computed analytically from the small
    (unpadded) projection.

    Let $X$ be the large, zero-padded projection and $x$ the small projection, with
    sizes $(H, W)$ and $(h, w)$, respectively. Because the padding is zero, the
    mean of the zero-padded projection in terms of the small projection is:

    .. math::
        \mu(X) = \frac{1}{H W} \sum_{i=1}^{h} \sum_{j=1}^{w} x_{ij}
               = \frac{h w}{H W} \mu(x)

    The padded region contains $H W - h w$ zero-valued pixels, each contributing
    $\mu(X)^2$ to the sum of squared deviations, so the (population) variance of
    the zero-padded projection is:

    .. math::
        \mathrm{Var}(X) = \frac{1}{H W} \left(
            \sum_{i=1}^{h} \sum_{j=1}^{w} (x_{ij} - \mu(X))^2
            + (H W - h w) \, \mu(X)^2 \right)

    Parameters
    ----------
    projections : torch.Tensor
        `(..., h, w)` real-space projections of the template (in small space).
        The tensor is not modified.
    small_shape : tuple[int, int]
        `(h, w)` shape of the template (in real space).
    large_shape : tuple[int, int]
        `(h_im, w_im)` shape of the image (in real space).

    Returns
    -------
    projections: torch.Tensor
        `(..., h, w)` edge-mean subtracted projections, scaled so that the
        variance of the zero-padded `(..., h_im, w_im)` projection is 1.
    """
    h, w = small_shape
    h_im, w_im = large_shape

    # Border pixels with each corner counted once, preserving batch dimensions
    edge_pixels = torch.cat(
        [
            projections[..., 0, :],  # top edge, (..., w)
            projections[..., -1, :],  # bottom edge, (..., w)
            projections[..., 1:-1, 0],  # left edge without corners, (..., h - 2)
            projections[..., 1:-1, -1],  # right edge without corners, (..., h - 2)
        ],
        dim=-1,
    )

    # Subtract the edge mean out of place so the caller's tensor is left untouched
    edge_mean = edge_pixels.mean(dim=-1)[..., None, None]
    projections = projections - edge_mean

    # Mean of the (virtual) zero-padded projection: mu(X) = (h w / H W) mu(x)
    n_small = h * w
    n_large = h_im * w_im
    padded_mean = projections.mean(dim=(-2, -1), keepdim=True) * (n_small / n_large)

    # Squared deviations over the small region plus the contribution of the
    # (H W - h w) zero-valued padding pixels
    sum_sq = ((projections - padded_mean) ** 2).sum(dim=(-2, -1), keepdim=True)
    sum_sq = sum_sq + (n_large - n_small) * padded_mean**2
    padded_variance = sum_sq / n_large

    return projections / torch.sqrt(padded_variance)
@@ -0,0 +1,40 @@
1
+ import torch
2
+ import torch_2dtm
3
+
4
+ from scipy.stats import special_ortho_group
5
+
6
+
7
def test_template_match_dft_2d():
    """Check output shape and finiteness of batched 2D template matching."""
    # Seed torch so failures are reproducible (scipy is seeded via random_state)
    torch.manual_seed(0)

    # 1. Random image and its FFT
    image_size = (128, 128)
    image = torch.randn(*image_size, dtype=torch.float32)
    image_dft = torch.fft.rfftn(image, dim=(0, 1))  # Shape: (128, 65)

    # 2. Random 3D template and its FFT
    template_size = (64, 64, 64)
    template = torch.randn(*template_size, dtype=torch.float32)
    template_dft = torch.fft.rfftn(template, dim=(0, 1, 2))  # Shape: (64, 64, 33)

    # 3. Batch of random rotation matrices with shape (b, 3, 3)
    num_orientations = 10
    rotation_matrices = torch.tensor(
        special_ortho_group.rvs(dim=3, size=num_orientations, random_state=0),
        dtype=torch.float32,
    )

    # 4. Identity Fourier-space filters with extra leading batch dimensions.
    # Filters operate on rffts of the 2D projections: (..., h, w // 2 + 1)
    filters_shape = (5, 4, 3, template_size[0], template_size[1] // 2 + 1)
    filters = torch.ones(filters_shape, dtype=torch.complex64)

    # Perform template matching
    cross_correlation = torch_2dtm.match_template_dft_2d(
        image_dft=image_dft,
        template_dft=template_dft,
        rotation_matrices=rotation_matrices,
        filters=filters,
    )

    # Correct output shape is (..., num_orientations, h_im, w_im)
    assert cross_correlation.shape == (5, 4, 3, num_orientations, *image_size)
    # Normalization must not produce NaN/inf (e.g. from a zero/negative variance)
    assert torch.isfinite(cross_correlation).all()