torch-2dtm 0.5.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_2dtm-0.5.0/.gitignore +105 -0
- torch_2dtm-0.5.0/LICENSE +29 -0
- torch_2dtm-0.5.0/PKG-INFO +96 -0
- torch_2dtm-0.5.0/README.md +72 -0
- torch_2dtm-0.5.0/pyproject.toml +112 -0
- torch_2dtm-0.5.0/src/torch_2dtm/__init__.py +9 -0
- torch_2dtm-0.5.0/src/torch_2dtm/cross_correlate.py +90 -0
- torch_2dtm-0.5.0/src/torch_2dtm/py.typed +5 -0
- torch_2dtm-0.5.0/src/torch_2dtm/utils.py +88 -0
- torch_2dtm-0.5.0/tests/test_torch_2dtm.py +40 -0
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
# Byte-compiled / optimized / DLL files
|
|
2
|
+
__pycache__/
|
|
3
|
+
*.py[cod]
|
|
4
|
+
*$py.class
|
|
5
|
+
|
|
6
|
+
# C extensions
|
|
7
|
+
*.so
|
|
8
|
+
|
|
9
|
+
# Distribution / packaging
|
|
10
|
+
.Python
|
|
11
|
+
env/
|
|
12
|
+
build/
|
|
13
|
+
develop-eggs/
|
|
14
|
+
dist/
|
|
15
|
+
downloads/
|
|
16
|
+
eggs/
|
|
17
|
+
.eggs/
|
|
18
|
+
lib/
|
|
19
|
+
lib64/
|
|
20
|
+
parts/
|
|
21
|
+
sdist/
|
|
22
|
+
var/
|
|
23
|
+
wheels/
|
|
24
|
+
*.egg-info/
|
|
25
|
+
.installed.cfg
|
|
26
|
+
*.egg
|
|
27
|
+
|
|
28
|
+
.DS_Store
|
|
29
|
+
|
|
30
|
+
# PyInstaller
|
|
31
|
+
*.manifest
|
|
32
|
+
*.spec
|
|
33
|
+
|
|
34
|
+
# Installer logs
|
|
35
|
+
pip-log.txt
|
|
36
|
+
pip-delete-this-directory.txt
|
|
37
|
+
|
|
38
|
+
# Unit test / coverage reports
|
|
39
|
+
htmlcov/
|
|
40
|
+
.tox/
|
|
41
|
+
.coverage
|
|
42
|
+
.coverage.*
|
|
43
|
+
.cache
|
|
44
|
+
nosetests.xml
|
|
45
|
+
coverage.xml
|
|
46
|
+
*.cover
|
|
47
|
+
.hypothesis/
|
|
48
|
+
.pytest_cache/
|
|
49
|
+
|
|
50
|
+
# Files downloaded for unit tests
|
|
51
|
+
**/tests/tmp/
|
|
52
|
+
|
|
53
|
+
# Translations
|
|
54
|
+
*.mo
|
|
55
|
+
*.pot
|
|
56
|
+
|
|
57
|
+
# Django stuff:
|
|
58
|
+
*.log
|
|
59
|
+
local_settings.py
|
|
60
|
+
|
|
61
|
+
# Flask stuff:
|
|
62
|
+
instance/
|
|
63
|
+
.webassets-cache
|
|
64
|
+
|
|
65
|
+
# Scrapy stuff:
|
|
66
|
+
.scrapy
|
|
67
|
+
|
|
68
|
+
# Sphinx documentation
|
|
69
|
+
docs/_build/
|
|
70
|
+
|
|
71
|
+
# PyBuilder
|
|
72
|
+
target/
|
|
73
|
+
|
|
74
|
+
# Jupyter Notebook
|
|
75
|
+
.ipynb_checkpoints
|
|
76
|
+
|
|
77
|
+
# dotenv
|
|
78
|
+
.env
|
|
79
|
+
|
|
80
|
+
# virtualenv
|
|
81
|
+
.venv
|
|
82
|
+
venv/
|
|
83
|
+
ENV/
|
|
84
|
+
|
|
85
|
+
# Spyder project settings
|
|
86
|
+
.spyderproject
|
|
87
|
+
.spyproject
|
|
88
|
+
|
|
89
|
+
# Rope project settings
|
|
90
|
+
.ropeproject
|
|
91
|
+
|
|
92
|
+
# mkdocs documentation
|
|
93
|
+
/site
|
|
94
|
+
|
|
95
|
+
# mypy
|
|
96
|
+
.mypy_cache/
|
|
97
|
+
|
|
98
|
+
# ruff
|
|
99
|
+
.ruff_cache/
|
|
100
|
+
|
|
101
|
+
# IDEs
|
|
102
|
+
.idea/
|
|
103
|
+
.vscode/
|
|
104
|
+
|
|
105
|
+
lightning_logs/
|
torch_2dtm-0.5.0/LICENSE
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
BSD 3-Clause License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2020, TeamTomo
|
|
4
|
+
All rights reserved.
|
|
5
|
+
|
|
6
|
+
Redistribution and use in source and binary forms, with or without
|
|
7
|
+
modification, are permitted provided that the following conditions are met:
|
|
8
|
+
|
|
9
|
+
1. Redistributions of source code must retain the above copyright notice, this
|
|
10
|
+
list of conditions and the following disclaimer.
|
|
11
|
+
|
|
12
|
+
2. Redistributions in binary form must reproduce the above copyright notice,
|
|
13
|
+
this list of conditions and the following disclaimer in the documentation
|
|
14
|
+
and/or other materials provided with the distribution.
|
|
15
|
+
|
|
16
|
+
3. Neither the name of the copyright holder nor the names of its
|
|
17
|
+
contributors may be used to endorse or promote products derived from
|
|
18
|
+
this software without specific prior written permission.
|
|
19
|
+
|
|
20
|
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
21
|
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
22
|
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
23
|
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
24
|
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
25
|
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
26
|
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
27
|
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
28
|
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
29
|
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: torch-2dtm
|
|
3
|
+
Version: 0.5.0
|
|
4
|
+
Summary: 2D template matching in pytorch
|
|
5
|
+
Project-URL: homepage, https://github.com/teamtomo/teamtomo
|
|
6
|
+
Project-URL: repository, https://github.com/teamtomo/teamtomo
|
|
7
|
+
Author-email: Josh Dickerson <jdickerson@berkeley.edu>, Matthew Giammar <matthew_giammar@berkeley.edu>, Alister Burt <alisterburt@gmail.com>
|
|
8
|
+
License: BSD-3-Clause
|
|
9
|
+
License-File: LICENSE
|
|
10
|
+
Classifier: Development Status :: 3 - Alpha
|
|
11
|
+
Classifier: License :: OSI Approved :: BSD License
|
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.14
|
|
17
|
+
Classifier: Typing :: Typed
|
|
18
|
+
Requires-Python: >=3.11
|
|
19
|
+
Requires-Dist: einops
|
|
20
|
+
Requires-Dist: setuptools
|
|
21
|
+
Requires-Dist: torch
|
|
22
|
+
Requires-Dist: torch-fourier-slice
|
|
23
|
+
Description-Content-Type: text/markdown
|
|
24
|
+
|
|
25
|
+
# torch-2dtm
|
|
26
|
+
|
|
27
|
+
[](https://github.com/teamtomo/torch-2dtm/raw/main/LICENSE)
|
|
28
|
+
[](https://pypi.org/project/torch-2dtm)
|
|
29
|
+
[](https://python.org)
|
|
30
|
+
[](https://github.com/teamtomo/torch-2dtm/actions/workflows/ci.yml)
|
|
31
|
+
[](https://codecov.io/gh/teamtomo/torch-2dtm)
|
|
32
|
+
|
|
33
|
+
## Overview
|
|
34
|
+
|
|
35
|
+
torch-2dtm is a Python package for efficient template matching of
|
|
36
|
+
2D projections of a 3D template with a 2D image in PyTorch.
|
|
37
|
+
|
|
38
|
+
This is implemented for cryo-EM applications, see
|
|
39
|
+
[Rickgauer et al. 2017 eLife](https://doi.org/10.7554/eLife.25648) for details.
|
|
40
|
+
|
|
41
|
+
## Features
|
|
42
|
+
|
|
43
|
+
- Fast 2D template matching using Fourier transforms
|
|
44
|
+
- Batch processing over orientations
|
|
45
|
+
- Batch processing over Fourier space filters (e.g. for defocus sweeps)
|
|
46
|
+
- GPU acceleration through PyTorch
|
|
47
|
+
|
|
48
|
+
Projections are calculated on-the-fly using
|
|
49
|
+
[*torch-fourier-slice*](https://github.com/teamtomo/torch-fourier-slice).
|
|
50
|
+
|
|
51
|
+
## Installation
|
|
52
|
+
|
|
53
|
+
```bash
|
|
54
|
+
pip install torch-2dtm
|
|
55
|
+
```
|
|
56
|
+
|
|
57
|
+
## Basic Usage
|
|
58
|
+
|
|
59
|
+
```python
|
|
60
|
+
import torch
|
|
61
|
+
import torch_2dtm
|
|
62
|
+
from scipy.stats import special_ortho_group
|
|
63
|
+
|
|
64
|
+
# Create random test data
|
|
65
|
+
# 1. Create a random image and compute its FFT
|
|
66
|
+
image_size = (128, 128)
|
|
67
|
+
image = torch.randn(*image_size, dtype=torch.float32)
|
|
68
|
+
image_dft = torch.fft.rfftn(image, dim=(0, 1)) # Shape: (128, 65)
|
|
69
|
+
|
|
70
|
+
# 2. Create a random 3D template and compute its FFT
|
|
71
|
+
template_size = (64, 64, 64)
|
|
72
|
+
template = torch.randn(*template_size, dtype=torch.float32)
|
|
73
|
+
template_dft = torch.fft.rfftn(template, dim=(0, 1, 2)) # Shape: (64, 64, 33)
|
|
74
|
+
|
|
75
|
+
# 3. Create a batch of random rotation matrices with shape (b, 3, 3)
|
|
76
|
+
num_orientations = 10
|
|
77
|
+
rotation_matrices = torch.tensor(special_ortho_group.rvs(size=num_orientations, dim=3), dtype=torch.float32)
|
|
78
|
+
|
|
79
|
+
# 4. Create an arbitrary stack of Fourier space filters (identity filter in this example)
|
|
80
|
+
# These filters operate on rffts of the 2D projection images
|
|
81
|
+
# Filter shape: (..., h, w // 2 + 1)
|
|
82
|
+
filters = torch.ones(template_size[0], template_size[1] // 2 + 1, dtype=torch.complex64)
|
|
83
|
+
|
|
84
|
+
# Perform template matching
|
|
85
|
+
cross_correlation = torch_2dtm.match_template_dft_2d(
|
|
86
|
+
image_dft=image_dft,
|
|
87
|
+
template_dft=template_dft,
|
|
88
|
+
rotation_matrices=rotation_matrices,
|
|
89
|
+
filters=filters
|
|
90
|
+
)
|
|
91
|
+
# The result has shape (..., num_orientations, image_height, image_width)
|
|
92
|
+
```
|
|
93
|
+
|
|
94
|
+
## License
|
|
95
|
+
|
|
96
|
+
This project is licensed under the BSD 3-Clause License - see the LICENSE file for details.
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
# torch-2dtm
|
|
2
|
+
|
|
3
|
+
[](https://github.com/teamtomo/torch-2dtm/raw/main/LICENSE)
|
|
4
|
+
[](https://pypi.org/project/torch-2dtm)
|
|
5
|
+
[](https://python.org)
|
|
6
|
+
[](https://github.com/teamtomo/torch-2dtm/actions/workflows/ci.yml)
|
|
7
|
+
[](https://codecov.io/gh/teamtomo/torch-2dtm)
|
|
8
|
+
|
|
9
|
+
## Overview
|
|
10
|
+
|
|
11
|
+
torch-2dtm is a Python package for efficient template matching of
|
|
12
|
+
2D projections of a 3D template with a 2D image in PyTorch.
|
|
13
|
+
|
|
14
|
+
This is implemented for cryo-EM applications, see
|
|
15
|
+
[Rickgauer et al. 2017 eLife](https://doi.org/10.7554/eLife.25648) for details.
|
|
16
|
+
|
|
17
|
+
## Features
|
|
18
|
+
|
|
19
|
+
- Fast 2D template matching using Fourier transforms
|
|
20
|
+
- Batch processing over orientations
|
|
21
|
+
- Batch processing over Fourier space filters (e.g. for defocus sweeps)
|
|
22
|
+
- GPU acceleration through PyTorch
|
|
23
|
+
|
|
24
|
+
Projections are calculated on-the-fly using
|
|
25
|
+
[*torch-fourier-slice*](https://github.com/teamtomo/torch-fourier-slice).
|
|
26
|
+
|
|
27
|
+
## Installation
|
|
28
|
+
|
|
29
|
+
```bash
|
|
30
|
+
pip install torch-2dtm
|
|
31
|
+
```
|
|
32
|
+
|
|
33
|
+
## Basic Usage
|
|
34
|
+
|
|
35
|
+
```python
|
|
36
|
+
import torch
|
|
37
|
+
import torch_2dtm
|
|
38
|
+
from scipy.stats import special_ortho_group
|
|
39
|
+
|
|
40
|
+
# Create random test data
|
|
41
|
+
# 1. Create a random image and compute its FFT
|
|
42
|
+
image_size = (128, 128)
|
|
43
|
+
image = torch.randn(*image_size, dtype=torch.float32)
|
|
44
|
+
image_dft = torch.fft.rfftn(image, dim=(0, 1)) # Shape: (128, 65)
|
|
45
|
+
|
|
46
|
+
# 2. Create a random 3D template and compute its FFT
|
|
47
|
+
template_size = (64, 64, 64)
|
|
48
|
+
template = torch.randn(*template_size, dtype=torch.float32)
|
|
49
|
+
template_dft = torch.fft.rfftn(template, dim=(0, 1, 2)) # Shape: (64, 64, 33)
|
|
50
|
+
|
|
51
|
+
# 3. Create a batch of random rotation matrices with shape (b, 3, 3)
|
|
52
|
+
num_orientations = 10
|
|
53
|
+
rotation_matrices = torch.tensor(special_ortho_group.rvs(size=num_orientations, dim=3), dtype=torch.float32)
|
|
54
|
+
|
|
55
|
+
# 4. Create an arbitrary stack of Fourier space filters (identity filter in this example)
|
|
56
|
+
# These filters operate on rffts of the 2D projection images
|
|
57
|
+
# Filter shape: (..., h, w // 2 + 1)
|
|
58
|
+
filters = torch.ones(template_size[0], template_size[1] // 2 + 1, dtype=torch.complex64)
|
|
59
|
+
|
|
60
|
+
# Perform template matching
|
|
61
|
+
cross_correlation = torch_2dtm.match_template_dft_2d(
|
|
62
|
+
image_dft=image_dft,
|
|
63
|
+
template_dft=template_dft,
|
|
64
|
+
rotation_matrices=rotation_matrices,
|
|
65
|
+
filters=filters
|
|
66
|
+
)
|
|
67
|
+
# The result has shape (..., num_orientations, image_height, image_width)
|
|
68
|
+
```
|
|
69
|
+
|
|
70
|
+
## License
|
|
71
|
+
|
|
72
|
+
This project is licensed under the BSD 3-Clause License - see the LICENSE file for details.
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
# https://peps.python.org/pep-0517/
|
|
2
|
+
[build-system]
|
|
3
|
+
requires = ["hatchling", "hatch-vcs"]
|
|
4
|
+
build-backend = "hatchling.build"
|
|
5
|
+
|
|
6
|
+
# https://hatch.pypa.io/latest/config/metadata/
|
|
7
|
+
[tool.hatch.version]
|
|
8
|
+
source = "vcs"
|
|
9
|
+
tag-pattern = "^torch-2dtm@v(?P<version>.+)$"
|
|
10
|
+
fallback-version = "0.5.0"
|
|
11
|
+
|
|
12
|
+
[tool.hatch.version.raw-options]
|
|
13
|
+
search_parent_directories = true
|
|
14
|
+
# Parse tags of the form: <package-name>@v<semver>
|
|
15
|
+
tag_regex = "^torch-2dtm@v(?P<version>\\d+\\.\\d+\\.\\d+.*)$"
|
|
16
|
+
# Constrain git-describe so it only considers this package's own tags (torch-2dtm@v*), not other workspace tags.
|
|
17
|
+
# See https://github.com/ofek/hatch-vcs/issues/71
|
|
18
|
+
git_describe_command = "git describe --dirty --tags --long --match 'torch-2dtm@v[0-9]*.[0-9]*.[0-9]*'"
|
|
19
|
+
|
|
20
|
+
# read more about configuring hatch at:
|
|
21
|
+
# https://hatch.pypa.io/latest/config/build/
|
|
22
|
+
[tool.hatch.build.targets.wheel]
|
|
23
|
+
only-include = ["src"]
|
|
24
|
+
sources = ["src"]
|
|
25
|
+
|
|
26
|
+
# https://peps.python.org/pep-0621/
|
|
27
|
+
[project]
|
|
28
|
+
name = "torch-2dtm"
|
|
29
|
+
dynamic = ["version"]
|
|
30
|
+
description = "2D template matching in pytorch"
|
|
31
|
+
readme = "README.md"
|
|
32
|
+
requires-python = ">=3.11"
|
|
33
|
+
license = { text = "BSD-3-Clause" }
|
|
34
|
+
authors = [
|
|
35
|
+
{ name = "Josh Dickerson", email = "jdickerson@berkeley.edu" },
|
|
36
|
+
{ name = "Matthew Giammar", email = "matthew_giammar@berkeley.edu" },
|
|
37
|
+
{ name = "Alister Burt", email = "alisterburt@gmail.com" },
|
|
38
|
+
]
|
|
39
|
+
# https://pypi.org/classifiers/
|
|
40
|
+
classifiers = [
|
|
41
|
+
"Development Status :: 3 - Alpha",
|
|
42
|
+
"License :: OSI Approved :: BSD License",
|
|
43
|
+
"Programming Language :: Python :: 3",
|
|
44
|
+
"Programming Language :: Python :: 3.11",
|
|
45
|
+
"Programming Language :: Python :: 3.12",
|
|
46
|
+
"Programming Language :: Python :: 3.13",
|
|
47
|
+
"Programming Language :: Python :: 3.14",
|
|
48
|
+
"Typing :: Typed",
|
|
49
|
+
]
|
|
50
|
+
# add your package dependencies here
|
|
51
|
+
dependencies = [
|
|
52
|
+
"torch",
|
|
53
|
+
"einops",
|
|
54
|
+
"torch-fourier-slice",
|
|
55
|
+
"setuptools", # presumably required at runtime by torch.compile; TODO confirm
|
|
56
|
+
]
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# https://peps.python.org/pep-0735/ (dependency groups)
|
|
60
|
+
# development-only groups, not installable extras (e.g. `uv sync --group test` or `pip install --group test` with pip >= 25.1)
|
|
61
|
+
[dependency-groups]
|
|
62
|
+
# add dependencies used for testing here
|
|
63
|
+
test = ["pytest", "pytest-cov", "scipy"]
|
|
64
|
+
# add anything else you like to have in your dev environment here
|
|
65
|
+
dev = [
|
|
66
|
+
{ include-group = "test" },
|
|
67
|
+
"ipython",
|
|
68
|
+
"pdbpp", # https://github.com/pdbpp/pdbpp
|
|
69
|
+
"rich", # https://github.com/Textualize/rich
|
|
70
|
+
]
|
|
71
|
+
|
|
72
|
+
[project.urls]
|
|
73
|
+
homepage = "https://github.com/teamtomo/teamtomo"
|
|
74
|
+
repository = "https://github.com/teamtomo/teamtomo"
|
|
75
|
+
|
|
76
|
+
# Entry points
|
|
77
|
+
# https://peps.python.org/pep-0621/#entry-points
|
|
78
|
+
# same as console_scripts entry point
|
|
79
|
+
# [project.scripts]
|
|
80
|
+
# torch-2dtm-cli = "torch_2dtm:main_cli"
|
|
81
|
+
|
|
82
|
+
# [project.entry-points."some.group"]
|
|
83
|
+
# tomatoes = "torch_2dtm:main_tomatoes"
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
# https://docs.pytest.org/
|
|
87
|
+
[tool.pytest.ini_options]
|
|
88
|
+
minversion = "7.0"
|
|
89
|
+
testpaths = ["tests"]
|
|
90
|
+
filterwarnings = ["error"]
|
|
91
|
+
|
|
92
|
+
# https://coverage.readthedocs.io/
|
|
93
|
+
[tool.coverage.report]
|
|
94
|
+
show_missing = true
|
|
95
|
+
exclude_lines = [
|
|
96
|
+
"pragma: no cover",
|
|
97
|
+
"if TYPE_CHECKING:",
|
|
98
|
+
"@overload",
|
|
99
|
+
"except ImportError",
|
|
100
|
+
"\\.\\.\\.",
|
|
101
|
+
"raise NotImplementedError()",
|
|
102
|
+
"pass",
|
|
103
|
+
]
|
|
104
|
+
|
|
105
|
+
[tool.coverage.run]
|
|
106
|
+
source = ["torch_2dtm"]
|
|
107
|
+
|
|
108
|
+
# https://github.com/mgedmin/check-manifest#configuration
|
|
109
|
+
# add files that you want check-manifest to explicitly ignore here
|
|
110
|
+
# (files that are in the repo but shouldn't go in the package)
|
|
111
|
+
[tool.check-manifest]
|
|
112
|
+
ignore = [".pre-commit-config.yaml", ".ruff_cache/**/*", "tests/**/*"]
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""Cross-correlation functions."""
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import platform
|
|
5
|
+
import einops
|
|
6
|
+
from einops._torch_specific import allow_ops_in_compiled_graph
|
|
7
|
+
from torch_fourier_slice import extract_central_slices_rfft_3d
|
|
8
|
+
|
|
9
|
+
from torch_2dtm.utils import normalize_template_projection
|
|
10
|
+
|
|
11
|
+
# Register einops operations so they may appear inside torch.compile graphs,
# then build a compiled variant of the normalization helper. The compile
# backend depends on the host OS: "aot_eager" on Linux (more stable than
# inductor there) and "inductor" on every other platform (macOS, Windows, ...).
allow_ops_in_compiled_graph()
COMPILE_BACKEND = "aot_eager" if platform.system() == "Linux" else "inductor"
normalize_template_projection_compiled = torch.compile(
    normalize_template_projection,
    backend=COMPILE_BACKEND,
)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def match_template_dft_2d(
    image_dft: torch.Tensor,
    template_dft: torch.Tensor,
    rotation_matrices: torch.Tensor,
    filters: torch.Tensor,
) -> torch.Tensor:
    """Batched projection and cross-correlation with a set of filters.

    For every rotation matrix a central slice is extracted from the template's
    Fourier transform and multiplied by every filter in ``filters``. The result
    is transformed to a real-space projection, normalized, zero-padded to the
    image size and cross-correlated with the image.

    Note that this function returns a cross-correlation image which is the
    same size as the input image prior to FFT calculation.

    Parameters
    ----------
    image_dft : torch.Tensor
        `(h_im, w_im // 2 + 1)` Fourier transform (rfft) of the real-space image.
        Any filters etc. are assumed to have already been applied to this image.
        The real-space width is inferred as ``2 * (image_dft.shape[-1] - 1)``,
        so an even image width is assumed.
    template_dft : torch.Tensor
        `(d, h, w // 2 + 1)` fftshifted Fourier transform (rfft) of the
        real-valued template volume to take Fourier slices from. The real-space
        width is inferred as ``2 * (template_dft.shape[-1] - 1)`` (even width
        assumed).
    rotation_matrices : torch.Tensor
        `(b, 3, 3)` batched rotation matrices to rotate slices sampled from the
        template Fourier transform.
    filters : torch.Tensor
        `(..., h, w // 2 + 1)` filters multiplied into the rfft Fourier slices.
        The slices are ifftshifted before filtering, so the filters must use the
        standard (non-fftshifted) rfft layout with the DC component at
        ``[..., 0, 0]``. Any leading dimensions (e.g. a defocus sweep) are batch
        dimensions that end up in front of the orientation dimension.

    Returns
    -------
    torch.Tensor
        Cross-correlation of the image with the template volume for each
        filter and orientation, with shape `(..., b, h_im, w_im)` where ``...``
        are the leading batch dimensions of ``filters``.
    """
    # Recover real-space sizes; rfft stores only w // 2 + 1 columns along the
    # last axis, so an even real-space width is assumed for both inputs.
    _, h, w_rfft = template_dft.shape
    h_im, w_im_rfft = image_dft.shape
    w = 2 * (w_rfft - 1)
    w_im = 2 * (w_im_rfft - 1)

    # Extract central slice(s) from the template volume, one per orientation
    fourier_slices = extract_central_slices_rfft_3d(
        volume_rfft=template_dft,
        rotation_matrices=rotation_matrices,
    )  # (b, h, w // 2 + 1)
    # Undo the fftshift along the height axis so the DC term sits at [0, 0]
    fourier_slices = torch.fft.ifftshift(fourier_slices, dim=(-2,))
    fourier_slices[..., 0, 0] = 0 + 0j  # zero out the DC component (mean zero)
    fourier_slices *= -1  # flip contrast

    # Insert an orientation axis into the filters so they broadcast against
    # the slices: (..., 1, h, w') * (b, h, w') -> (..., b, h, w')
    filters = einops.rearrange(filters, "... h w -> ... 1 h w")
    fourier_slices = fourier_slices * filters

    # Back to real space; ifftshift moves the projection origin from the
    # corner to the centre of the (h, w) box. Then normalize relative to the
    # padded (image-sized) projection.
    projections = torch.fft.irfftn(fourier_slices, dim=(-2, -1))
    projections = torch.fft.ifftshift(projections, dim=(-2, -1))
    projections = normalize_template_projection_compiled(
        projections, (h, w), (h_im, w_im)
    )

    # Zero-padded forward Fourier transform to the image size
    projections_dft = torch.fft.rfftn(projections, dim=(-2, -1), s=(h_im, w_im))
    projections_dft[..., 0, 0] = 0 + 0j  # zero the DC component (mean zero)

    # Cross-correlation theorem: IFFT(image * conj(template))
    cross_correlation_dft = image_dft * torch.conj(projections_dft)
    cross_correlation = torch.fft.irfftn(
        cross_correlation_dft, dim=(-2, -1), s=(h_im, w_im)
    )

    return cross_correlation  # (..., b, h_im, w_im)
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
"""Utility functions associated with backend functions."""
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import einops
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def normalize_template_projection(
    projections: torch.Tensor,
    small_shape: tuple[int, int],
    large_shape: tuple[int, int],
) -> torch.Tensor:
    r"""Subtract mean of edge values and set variance to 1 (in large shape).

    The projections are first shifted so that the mean of their border pixels
    is zero. They are then scaled so that, if each projection were zero-padded
    from ``small_shape`` to ``large_shape``, the (population) variance of the
    padded image would be exactly 1 (up to floating-point rounding). The padded
    image is never materialized; its mean and variance are computed
    analytically from the small projection.

    Let $X$ be the large, zero-padded projection and $x$ the small projection,
    with sizes $(H, W)$ and $(h, w)$ respectively. Because the padding is zero,
    the mean of the padded projection is

    .. math::
        \mu(X) = \frac{1}{H \cdot W} \sum_{i=1}^{h} \sum_{j=1}^{w} x_{ij}
               = \frac{h \cdot w}{H \cdot W} \mu(x)

    and its variance splits into a sum over the $h \cdot w$ original pixels
    plus the $H \cdot W - h \cdot w$ zero-valued padded pixels:

    .. math::
        Var(X) = \frac{1}{H \cdot W} \left(
            \sum_{i=1}^{h} \sum_{j=1}^{w} (x_{ij} - \mu(X))^2
            + (H \cdot W - h \cdot w) \, \mu(X)^2
        \right)

    Parameters
    ----------
    projections : torch.Tensor
        `(..., h, w)` real-space projections of the template (in small space).
        The tensor is not modified.
    small_shape : tuple[int, int]
        `(h, w)` shape of the template (in real space).
    large_shape : tuple[int, int]
        `(h_im, w_im)` shape of the image (in real space).

    Returns
    -------
    projections: torch.Tensor
        `(..., h, w)` edge-mean subtracted projections
        normalized so variance of zero-padded projection would be 1.
    """
    h, w = small_shape
    h_im, w_im = large_shape

    # Border pixels with each corner counted once: full top/bottom rows plus
    # the left/right columns without their first and last rows.
    edge_pixels = torch.cat(
        [
            projections[..., 0, :],
            projections[..., -1, :],
            projections[..., 1:-1, 0],
            projections[..., 1:-1, -1],
        ],
        dim=-1,
    )  # (..., 2 * w + 2 * (h - 2))
    edge_mean = edge_pixels.mean(dim=-1, keepdim=True).unsqueeze(-1)  # (..., 1, 1)
    # Out-of-place so the caller's tensor is left untouched.
    projections = projections - edge_mean

    # Mean of the zero-padded projection: the padding adds only zeros to the
    # sum, so the small-space mean is scaled by (h * w) / (H * W) exactly once.
    relative_size = (h * w) / (h_im * w_im)
    padded_mean = projections.mean(dim=(-2, -1), keepdim=True) * relative_size

    # Squared deviations over the original h * w pixels ...
    squared_dev_sum = ((projections - padded_mean) ** 2).sum(
        dim=(-2, -1), keepdim=True
    )
    # ... plus the (H * W - h * w) zero-valued padded pixels, each of which
    # deviates from the padded mean by -padded_mean.
    n_padded = h_im * w_im - h * w
    padded_variance = (squared_dev_sum + n_padded * padded_mean**2) / (h_im * w_im)

    return projections / torch.sqrt(padded_variance)
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch_2dtm
|
|
3
|
+
|
|
4
|
+
from scipy.stats import special_ortho_group
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def test_template_match_dft_2d():
    """Output shape is (*filter_batch_dims, num_orientations, h_im, w_im)."""
    img_shape = (128, 128)
    vol_shape = (64, 64, 64)
    n_orientations = 10

    # Random real-space image -> rfft with shape (128, 65)
    img = torch.randn(*img_shape, dtype=torch.float32)
    img_dft = torch.fft.rfftn(img, dim=(0, 1))

    # Random real-space template volume -> rfft with shape (64, 64, 33)
    vol = torch.randn(*vol_shape, dtype=torch.float32)
    vol_dft = torch.fft.rfftn(vol, dim=(0, 1, 2))

    # Batch of random SO(3) rotations, shape (n_orientations, 3, 3)
    rotations = special_ortho_group.rvs(size=n_orientations, dim=3)
    rotations = torch.tensor(rotations, dtype=torch.float32)

    # Identity (all-ones) filters on the rfft of the 2D projections, with three
    # leading batch dimensions: (5, 4, 3, h, w // 2 + 1)
    filter_batch = (5, 4, 3)
    identity_filters = torch.ones(
        *filter_batch, vol_shape[0], vol_shape[1] // 2 + 1, dtype=torch.complex64
    )

    result = torch_2dtm.match_template_dft_2d(
        image_dft=img_dft,
        template_dft=vol_dft,
        rotation_matrices=rotations,
        filters=identity_filters,
    )

    # Filter batch dims come first, then orientations, then image dims
    assert result.shape == (*filter_batch, n_orientations, *img_shape)
|